mirror of
https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git
synced 2025-08-14 00:25:46 +02:00
Проверка 09.02.2025
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
|
||||
|
||||
__version__ = "0.110.0"
|
||||
__version__ = "0.115.8"
|
||||
|
||||
from starlette import status as status
|
||||
|
||||
|
@@ -2,6 +2,7 @@ from collections import deque
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@@ -24,7 +25,8 @@ from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from starlette.datastructures import UploadFile
|
||||
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
|
||||
|
||||
|
||||
sequence_annotation_to_type = {
|
||||
@@ -43,6 +45,8 @@ sequence_annotation_to_type = {
|
||||
|
||||
sequence_types = tuple(sequence_annotation_to_type.keys())
|
||||
|
||||
Url: Type[Any]
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from pydantic import TypeAdapter
|
||||
@@ -68,7 +72,7 @@ if PYDANTIC_V2:
|
||||
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
|
||||
)
|
||||
|
||||
Required = PydanticUndefined
|
||||
RequiredParam = PydanticUndefined
|
||||
Undefined = PydanticUndefined
|
||||
UndefinedType = PydanticUndefinedType
|
||||
evaluate_forwardref = eval_type_lenient
|
||||
@@ -127,7 +131,7 @@ if PYDANTIC_V2:
|
||||
)
|
||||
except ValidationError as exc:
|
||||
return None, _regenerate_error_with_loc(
|
||||
errors=exc.errors(), loc_prefix=loc
|
||||
errors=exc.errors(include_url=False), loc_prefix=loc
|
||||
)
|
||||
|
||||
def serialize(
|
||||
@@ -266,7 +270,7 @@ if PYDANTIC_V2:
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
|
||||
error = ValidationError.from_exception_data(
|
||||
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
|
||||
).errors()[0]
|
||||
).errors(include_url=False)[0]
|
||||
error["input"] = None
|
||||
return error # type: ignore[return-value]
|
||||
|
||||
@@ -277,6 +281,12 @@ if PYDANTIC_V2:
|
||||
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
||||
return BodyModel
|
||||
|
||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return [
|
||||
ModelField(field_info=field_info, name=name)
|
||||
for name, field_info in model.model_fields.items()
|
||||
]
|
||||
|
||||
else:
|
||||
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
||||
from pydantic import AnyUrl as Url # noqa: F401
|
||||
@@ -304,9 +314,10 @@ else:
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
ModelField as ModelField, # noqa: F401
|
||||
)
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
Required as Required, # noqa: F401
|
||||
)
|
||||
|
||||
# Keeping old "Required" functionality from Pydantic V1, without
|
||||
# shadowing typing.Required.
|
||||
RequiredParam: Any = Ellipsis # type: ignore[no-redef]
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
Undefined as Undefined,
|
||||
)
|
||||
@@ -511,6 +522,9 @@ else:
|
||||
BodyModel.__fields__[f.name] = f # type: ignore[index]
|
||||
return BodyModel
|
||||
|
||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return list(model.__fields__.values()) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _regenerate_error_with_loc(
|
||||
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
|
||||
@@ -530,6 +544,12 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
|
||||
|
||||
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_sequence(arg):
|
||||
return True
|
||||
return False
|
||||
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
|
||||
get_origin(annotation)
|
||||
)
|
||||
@@ -632,3 +652,8 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
|
||||
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return get_model_fields(model)
|
||||
|
@@ -40,7 +40,7 @@ from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, JSONResponse, Response
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
|
||||
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc, deprecated
|
||||
|
||||
AppType = TypeVar("AppType", bound="FastAPI")
|
||||
|
||||
@@ -902,7 +902,7 @@ class FastAPI(Starlette):
|
||||
A state object for the application. This is the same object for the
|
||||
entire application, it doesn't change from request to request.
|
||||
|
||||
You normally woudln't use this in FastAPI, for most of the cases you
|
||||
You normally wouldn't use this in FastAPI, for most of the cases you
|
||||
would instead use FastAPI dependencies.
|
||||
|
||||
This is simply inherited from Starlette.
|
||||
@@ -1019,7 +1019,7 @@ class FastAPI(Starlette):
|
||||
oauth2_redirect_url = root_path + oauth2_redirect_url
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=openapi_url,
|
||||
title=self.title + " - Swagger UI",
|
||||
title=f"{self.title} - Swagger UI",
|
||||
oauth2_redirect_url=oauth2_redirect_url,
|
||||
init_oauth=self.swagger_ui_init_oauth,
|
||||
swagger_ui_parameters=self.swagger_ui_parameters,
|
||||
@@ -1043,7 +1043,7 @@ class FastAPI(Starlette):
|
||||
root_path = req.scope.get("root_path", "").rstrip("/")
|
||||
openapi_url = root_path + self.openapi_url
|
||||
return get_redoc_html(
|
||||
openapi_url=openapi_url, title=self.title + " - ReDoc"
|
||||
openapi_url=openapi_url, title=f"{self.title} - ReDoc"
|
||||
)
|
||||
|
||||
self.add_route(self.redoc_url, redoc_html, include_in_schema=False)
|
||||
@@ -1056,7 +1056,7 @@ class FastAPI(Starlette):
|
||||
def add_api_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable[..., Coroutine[Any, Any, Response]],
|
||||
endpoint: Callable[..., Any],
|
||||
*,
|
||||
response_model: Any = Default(None),
|
||||
status_code: Optional[int] = None,
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||
from typing_extensions import Annotated, Doc, ParamSpec # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc, ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from contextlib import asynccontextmanager as asynccontextmanager
|
||||
from typing import AsyncGenerator, ContextManager, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.to_thread
|
||||
from anyio import CapacityLimiter
|
||||
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
|
||||
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
|
||||
@@ -28,7 +28,7 @@ async def contextmanager_in_threadpool(
|
||||
except Exception as e:
|
||||
ok = bool(
|
||||
await anyio.to_thread.run_sync(
|
||||
cm.__exit__, type(e), e, None, limiter=exit_limiter
|
||||
cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
|
||||
)
|
||||
)
|
||||
if not ok:
|
||||
|
@@ -24,7 +24,7 @@ from starlette.datastructures import Headers as Headers # noqa: F401
|
||||
from starlette.datastructures import QueryParams as QueryParams # noqa: F401
|
||||
from starlette.datastructures import State as State # noqa: F401
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
|
||||
class UploadFile(StarletteUploadFile):
|
||||
|
@@ -1,58 +1,37 @@
|
||||
from typing import Any, Callable, List, Optional, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple
|
||||
|
||||
from fastapi._compat import ModelField
|
||||
from fastapi.security.base import SecurityBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityRequirement:
|
||||
def __init__(
|
||||
self, security_scheme: SecurityBase, scopes: Optional[Sequence[str]] = None
|
||||
):
|
||||
self.security_scheme = security_scheme
|
||||
self.scopes = scopes
|
||||
security_scheme: SecurityBase
|
||||
scopes: Optional[Sequence[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dependant:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
path_params: Optional[List[ModelField]] = None,
|
||||
query_params: Optional[List[ModelField]] = None,
|
||||
header_params: Optional[List[ModelField]] = None,
|
||||
cookie_params: Optional[List[ModelField]] = None,
|
||||
body_params: Optional[List[ModelField]] = None,
|
||||
dependencies: Optional[List["Dependant"]] = None,
|
||||
security_schemes: Optional[List[SecurityRequirement]] = None,
|
||||
name: Optional[str] = None,
|
||||
call: Optional[Callable[..., Any]] = None,
|
||||
request_param_name: Optional[str] = None,
|
||||
websocket_param_name: Optional[str] = None,
|
||||
http_connection_param_name: Optional[str] = None,
|
||||
response_param_name: Optional[str] = None,
|
||||
background_tasks_param_name: Optional[str] = None,
|
||||
security_scopes_param_name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
use_cache: bool = True,
|
||||
path: Optional[str] = None,
|
||||
) -> None:
|
||||
self.path_params = path_params or []
|
||||
self.query_params = query_params or []
|
||||
self.header_params = header_params or []
|
||||
self.cookie_params = cookie_params or []
|
||||
self.body_params = body_params or []
|
||||
self.dependencies = dependencies or []
|
||||
self.security_requirements = security_schemes or []
|
||||
self.request_param_name = request_param_name
|
||||
self.websocket_param_name = websocket_param_name
|
||||
self.http_connection_param_name = http_connection_param_name
|
||||
self.response_param_name = response_param_name
|
||||
self.background_tasks_param_name = background_tasks_param_name
|
||||
self.security_scopes = security_scopes
|
||||
self.security_scopes_param_name = security_scopes_param_name
|
||||
self.name = name
|
||||
self.call = call
|
||||
self.use_cache = use_cache
|
||||
# Store the path to be able to re-generate a dependable from it in overrides
|
||||
self.path = path
|
||||
# Save the cache key at creation to optimize performance
|
||||
path_params: List[ModelField] = field(default_factory=list)
|
||||
query_params: List[ModelField] = field(default_factory=list)
|
||||
header_params: List[ModelField] = field(default_factory=list)
|
||||
cookie_params: List[ModelField] = field(default_factory=list)
|
||||
body_params: List[ModelField] = field(default_factory=list)
|
||||
dependencies: List["Dependant"] = field(default_factory=list)
|
||||
security_requirements: List[SecurityRequirement] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
call: Optional[Callable[..., Any]] = None
|
||||
request_param_name: Optional[str] = None
|
||||
websocket_param_name: Optional[str] = None
|
||||
http_connection_param_name: Optional[str] = None
|
||||
response_param_name: Optional[str] = None
|
||||
background_tasks_param_name: Optional[str] = None
|
||||
security_scopes_param_name: Optional[str] = None
|
||||
security_scopes: Optional[List[str]] = None
|
||||
use_cache: bool = True
|
||||
path: Optional[str] = None
|
||||
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import inspect
|
||||
from contextlib import AsyncExitStack, contextmanager
|
||||
from copy import deepcopy
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@@ -23,7 +24,7 @@ from fastapi._compat import (
|
||||
PYDANTIC_V2,
|
||||
ErrorWrapper,
|
||||
ModelField,
|
||||
Required,
|
||||
RequiredParam,
|
||||
Undefined,
|
||||
_regenerate_error_with_loc,
|
||||
copy_field_info,
|
||||
@@ -31,6 +32,7 @@ from fastapi._compat import (
|
||||
evaluate_forwardref,
|
||||
field_annotation_is_scalar,
|
||||
get_annotation_from_field_info,
|
||||
get_cached_model_fields,
|
||||
get_missing_field_error,
|
||||
is_bytes_field,
|
||||
is_bytes_sequence_field,
|
||||
@@ -54,11 +56,18 @@ from fastapi.logger import logger
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||
from fastapi.utils import create_response_field, get_path_param_names
|
||||
from fastapi.utils import create_model_field, get_path_param_names
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||
from starlette.datastructures import (
|
||||
FormData,
|
||||
Headers,
|
||||
ImmutableMultiDict,
|
||||
QueryParams,
|
||||
UploadFile,
|
||||
)
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.websockets import WebSocket
|
||||
@@ -79,17 +88,23 @@ multipart_incorrect_install_error = (
|
||||
)
|
||||
|
||||
|
||||
def check_file_field(field: ModelField) -> None:
|
||||
field_info = field.field_info
|
||||
if isinstance(field_info, params.Form):
|
||||
def ensure_multipart_is_installed() -> None:
|
||||
try:
|
||||
from python_multipart import __version__
|
||||
|
||||
# Import an attribute that can be mocked/deleted in testing
|
||||
assert __version__ > "0.0.12"
|
||||
except (ImportError, AssertionError):
|
||||
try:
|
||||
# __version__ is available in both multiparts, and can be mocked
|
||||
from multipart import __version__ # type: ignore
|
||||
from multipart import __version__ # type: ignore[no-redef,import-untyped]
|
||||
|
||||
assert __version__
|
||||
try:
|
||||
# parse_options_header is only available in the right multipart
|
||||
from multipart.multipart import parse_options_header # type: ignore
|
||||
from multipart.multipart import ( # type: ignore[import-untyped]
|
||||
parse_options_header,
|
||||
)
|
||||
|
||||
assert parse_options_header
|
||||
except ImportError:
|
||||
@@ -175,7 +190,7 @@ def get_flat_dependant(
|
||||
header_params=dependant.header_params.copy(),
|
||||
cookie_params=dependant.cookie_params.copy(),
|
||||
body_params=dependant.body_params.copy(),
|
||||
security_schemes=dependant.security_requirements.copy(),
|
||||
security_requirements=dependant.security_requirements.copy(),
|
||||
use_cache=dependant.use_cache,
|
||||
path=dependant.path,
|
||||
)
|
||||
@@ -194,14 +209,23 @@ def get_flat_dependant(
|
||||
return flat_dependant
|
||||
|
||||
|
||||
def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
|
||||
if not fields:
|
||||
return fields
|
||||
first_field = fields[0]
|
||||
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
return fields_to_extract
|
||||
return fields
|
||||
|
||||
|
||||
def get_flat_params(dependant: Dependant) -> List[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
return (
|
||||
flat_dependant.path_params
|
||||
+ flat_dependant.query_params
|
||||
+ flat_dependant.header_params
|
||||
+ flat_dependant.cookie_params
|
||||
)
|
||||
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
|
||||
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
|
||||
header_params = _get_flat_fields_from_params(flat_dependant.header_params)
|
||||
cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
|
||||
return path_params + query_params + header_params + cookie_params
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
@@ -258,16 +282,16 @@ def get_dependant(
|
||||
)
|
||||
for param_name, param in signature_params.items():
|
||||
is_path_param = param_name in path_param_names
|
||||
type_annotation, depends, param_field = analyze_param(
|
||||
param_details = analyze_param(
|
||||
param_name=param_name,
|
||||
annotation=param.annotation,
|
||||
value=param.default,
|
||||
is_path_param=is_path_param,
|
||||
)
|
||||
if depends is not None:
|
||||
if param_details.depends is not None:
|
||||
sub_dependant = get_param_sub_dependant(
|
||||
param_name=param_name,
|
||||
depends=depends,
|
||||
depends=param_details.depends,
|
||||
path=path,
|
||||
security_scopes=security_scopes,
|
||||
)
|
||||
@@ -275,18 +299,18 @@ def get_dependant(
|
||||
continue
|
||||
if add_non_field_param_to_dependency(
|
||||
param_name=param_name,
|
||||
type_annotation=type_annotation,
|
||||
type_annotation=param_details.type_annotation,
|
||||
dependant=dependant,
|
||||
):
|
||||
assert (
|
||||
param_field is None
|
||||
param_details.field is None
|
||||
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
|
||||
continue
|
||||
assert param_field is not None
|
||||
if is_body_param(param_field=param_field, is_path_param=is_path_param):
|
||||
dependant.body_params.append(param_field)
|
||||
assert param_details.field is not None
|
||||
if isinstance(param_details.field.field_info, params.Body):
|
||||
dependant.body_params.append(param_details.field)
|
||||
else:
|
||||
add_param_to_fields(field=param_field, dependant=dependant)
|
||||
add_param_to_fields(field=param_details.field, dependant=dependant)
|
||||
return dependant
|
||||
|
||||
|
||||
@@ -314,13 +338,20 @@ def add_non_field_param_to_dependency(
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamDetails:
|
||||
type_annotation: Any
|
||||
depends: Optional[params.Depends]
|
||||
field: Optional[ModelField]
|
||||
|
||||
|
||||
def analyze_param(
|
||||
*,
|
||||
param_name: str,
|
||||
annotation: Any,
|
||||
value: Any,
|
||||
is_path_param: bool,
|
||||
) -> Tuple[Any, Optional[params.Depends], Optional[ModelField]]:
|
||||
) -> ParamDetails:
|
||||
field_info = None
|
||||
depends = None
|
||||
type_annotation: Any = Any
|
||||
@@ -328,6 +359,7 @@ def analyze_param(
|
||||
if annotation is not inspect.Signature.empty:
|
||||
use_annotation = annotation
|
||||
type_annotation = annotation
|
||||
# Extract Annotated info
|
||||
if get_origin(use_annotation) is Annotated:
|
||||
annotated_args = get_args(annotation)
|
||||
type_annotation = annotated_args[0]
|
||||
@@ -342,17 +374,20 @@ def analyze_param(
|
||||
if isinstance(arg, (params.Param, params.Body, params.Depends))
|
||||
]
|
||||
if fastapi_specific_annotations:
|
||||
fastapi_annotation: Union[
|
||||
FieldInfo, params.Depends, None
|
||||
] = fastapi_specific_annotations[-1]
|
||||
fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
|
||||
fastapi_specific_annotations[-1]
|
||||
)
|
||||
else:
|
||||
fastapi_annotation = None
|
||||
# Set default for Annotated FieldInfo
|
||||
if isinstance(fastapi_annotation, FieldInfo):
|
||||
# Copy `field_info` because we mutate `field_info.default` below.
|
||||
field_info = copy_field_info(
|
||||
field_info=fastapi_annotation, annotation=use_annotation
|
||||
)
|
||||
assert field_info.default is Undefined or field_info.default is Required, (
|
||||
assert (
|
||||
field_info.default is Undefined or field_info.default is RequiredParam
|
||||
), (
|
||||
f"`{field_info.__class__.__name__}` default value cannot be set in"
|
||||
f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
|
||||
)
|
||||
@@ -360,10 +395,11 @@ def analyze_param(
|
||||
assert not is_path_param, "Path parameters cannot have default values"
|
||||
field_info.default = value
|
||||
else:
|
||||
field_info.default = Required
|
||||
field_info.default = RequiredParam
|
||||
# Get Annotated Depends
|
||||
elif isinstance(fastapi_annotation, params.Depends):
|
||||
depends = fastapi_annotation
|
||||
|
||||
# Get Depends from default value
|
||||
if isinstance(value, params.Depends):
|
||||
assert depends is None, (
|
||||
"Cannot specify `Depends` in `Annotated` and default value"
|
||||
@@ -374,6 +410,7 @@ def analyze_param(
|
||||
f" default value together for {param_name!r}"
|
||||
)
|
||||
depends = value
|
||||
# Get FieldInfo from default value
|
||||
elif isinstance(value, FieldInfo):
|
||||
assert field_info is None, (
|
||||
"Cannot specify FastAPI annotations in `Annotated` and default value"
|
||||
@@ -383,9 +420,13 @@ def analyze_param(
|
||||
if PYDANTIC_V2:
|
||||
field_info.annotation = type_annotation
|
||||
|
||||
# Get Depends from type annotation
|
||||
if depends is not None and depends.dependency is None:
|
||||
# Copy `depends` before mutating it
|
||||
depends = copy(depends)
|
||||
depends.dependency = type_annotation
|
||||
|
||||
# Handle non-param type annotations like Request
|
||||
if lenient_issubclass(
|
||||
type_annotation,
|
||||
(
|
||||
@@ -401,10 +442,11 @@ def analyze_param(
|
||||
assert (
|
||||
field_info is None
|
||||
), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
|
||||
# Handle default assignations, neither field_info nor depends was not found in Annotated nor default value
|
||||
elif field_info is None and depends is None:
|
||||
default_value = value if value is not inspect.Signature.empty else Required
|
||||
default_value = value if value is not inspect.Signature.empty else RequiredParam
|
||||
if is_path_param:
|
||||
# We might check here that `default_value is Required`, but the fact is that the same
|
||||
# We might check here that `default_value is RequiredParam`, but the fact is that the same
|
||||
# parameter might sometimes be a path parameter and sometimes not. See
|
||||
# `tests/test_infer_param_optionality.py` for an example.
|
||||
field_info = params.Path(annotation=use_annotation)
|
||||
@@ -418,7 +460,9 @@ def analyze_param(
|
||||
field_info = params.Query(annotation=use_annotation, default=default_value)
|
||||
|
||||
field = None
|
||||
# It's a field_info, not a dependency
|
||||
if field_info is not None:
|
||||
# Handle field_info.in_
|
||||
if is_path_param:
|
||||
assert isinstance(field_info, params.Path), (
|
||||
f"Cannot use `{field_info.__class__.__name__}` for path param"
|
||||
@@ -434,40 +478,37 @@ def analyze_param(
|
||||
field_info,
|
||||
param_name,
|
||||
)
|
||||
if isinstance(field_info, params.Form):
|
||||
ensure_multipart_is_installed()
|
||||
if not field_info.alias and getattr(field_info, "convert_underscores", None):
|
||||
alias = param_name.replace("_", "-")
|
||||
else:
|
||||
alias = field_info.alias or param_name
|
||||
field_info.alias = alias
|
||||
field = create_response_field(
|
||||
field = create_model_field(
|
||||
name=param_name,
|
||||
type_=use_annotation_from_field_info,
|
||||
default=field_info.default,
|
||||
alias=alias,
|
||||
required=field_info.default in (Required, Undefined),
|
||||
required=field_info.default in (RequiredParam, Undefined),
|
||||
field_info=field_info,
|
||||
)
|
||||
if is_path_param:
|
||||
assert is_scalar_field(
|
||||
field=field
|
||||
), "Path params must be of one of the supported types"
|
||||
elif isinstance(field_info, params.Query):
|
||||
assert (
|
||||
is_scalar_field(field)
|
||||
or is_scalar_sequence_field(field)
|
||||
or (
|
||||
lenient_issubclass(field.type_, BaseModel)
|
||||
# For Pydantic v1
|
||||
and getattr(field, "shape", 1) == 1
|
||||
)
|
||||
)
|
||||
|
||||
return type_annotation, depends, field
|
||||
|
||||
|
||||
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
|
||||
if is_path_param:
|
||||
assert is_scalar_field(
|
||||
field=param_field
|
||||
), "Path params must be of one of the supported types"
|
||||
return False
|
||||
elif is_scalar_field(field=param_field):
|
||||
return False
|
||||
elif isinstance(
|
||||
param_field.field_info, (params.Query, params.Header)
|
||||
) and is_scalar_sequence_field(param_field):
|
||||
return False
|
||||
else:
|
||||
assert isinstance(
|
||||
param_field.field_info, params.Body
|
||||
), f"Param: {param_field.name} can only be a request body, using Body()"
|
||||
return True
|
||||
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
|
||||
|
||||
|
||||
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
||||
@@ -519,6 +560,15 @@ async def solve_generator(
|
||||
return await stack.enter_async_context(cm)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SolvedDependency:
|
||||
values: Dict[str, Any]
|
||||
errors: List[Any]
|
||||
background_tasks: Optional[StarletteBackgroundTasks]
|
||||
response: Response
|
||||
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
|
||||
|
||||
|
||||
async def solve_dependencies(
|
||||
*,
|
||||
request: Union[Request, WebSocket],
|
||||
@@ -529,13 +579,8 @@ async def solve_dependencies(
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
||||
async_exit_stack: AsyncExitStack,
|
||||
) -> Tuple[
|
||||
Dict[str, Any],
|
||||
List[Any],
|
||||
Optional[StarletteBackgroundTasks],
|
||||
Response,
|
||||
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
|
||||
]:
|
||||
embed_body_fields: bool,
|
||||
) -> SolvedDependency:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Any] = []
|
||||
if response is None:
|
||||
@@ -576,28 +621,23 @@ async def solve_dependencies(
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
dependency_cache=dependency_cache,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
(
|
||||
sub_values,
|
||||
sub_errors,
|
||||
background_tasks,
|
||||
_, # the subdependency returns the same response we have
|
||||
sub_dependency_cache,
|
||||
) = solved_result
|
||||
dependency_cache.update(sub_dependency_cache)
|
||||
if sub_errors:
|
||||
errors.extend(sub_errors)
|
||||
background_tasks = solved_result.background_tasks
|
||||
dependency_cache.update(solved_result.dependency_cache)
|
||||
if solved_result.errors:
|
||||
errors.extend(solved_result.errors)
|
||||
continue
|
||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependant.cache_key]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
solved = await solve_generator(
|
||||
call=call, stack=async_exit_stack, sub_values=sub_values
|
||||
call=call, stack=async_exit_stack, sub_values=solved_result.values
|
||||
)
|
||||
elif is_coroutine_callable(call):
|
||||
solved = await call(**sub_values)
|
||||
solved = await call(**solved_result.values)
|
||||
else:
|
||||
solved = await run_in_threadpool(call, **sub_values)
|
||||
solved = await run_in_threadpool(call, **solved_result.values)
|
||||
if sub_dependant.name is not None:
|
||||
values[sub_dependant.name] = solved
|
||||
if sub_dependant.cache_key not in dependency_cache:
|
||||
@@ -624,7 +664,9 @@ async def solve_dependencies(
|
||||
body_values,
|
||||
body_errors,
|
||||
) = await request_body_to_args( # body_params checked above
|
||||
required_params=dependant.body_params, received_body=body
|
||||
body_fields=dependant.body_params,
|
||||
received_body=body,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
values.update(body_values)
|
||||
errors.extend(body_errors)
|
||||
@@ -644,142 +686,257 @@ async def solve_dependencies(
|
||||
values[dependant.security_scopes_param_name] = SecurityScopes(
|
||||
scopes=dependant.security_scopes
|
||||
)
|
||||
return values, errors, background_tasks, response, dependency_cache
|
||||
return SolvedDependency(
|
||||
values=values,
|
||||
errors=errors,
|
||||
background_tasks=background_tasks,
|
||||
response=response,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
|
||||
|
||||
def _validate_value_with_model_field(
|
||||
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
|
||||
) -> Tuple[Any, List[Any]]:
|
||||
if value is None:
|
||||
if field.required:
|
||||
return None, [get_missing_field_error(loc=loc)]
|
||||
else:
|
||||
return deepcopy(field.default), []
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
return None, [errors_]
|
||||
elif isinstance(errors_, list):
|
||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
return None, new_errors
|
||||
else:
|
||||
return v_, []
|
||||
|
||||
|
||||
def _get_multidict_value(
|
||||
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
|
||||
) -> Any:
|
||||
alias = alias or field.alias
|
||||
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
|
||||
value = values.getlist(alias)
|
||||
else:
|
||||
value = values.get(alias, None)
|
||||
if (
|
||||
value is None
|
||||
or (
|
||||
isinstance(field.field_info, params.Form)
|
||||
and isinstance(value, str) # For type checks
|
||||
and value == ""
|
||||
)
|
||||
or (is_sequence_field(field) and len(value) == 0)
|
||||
):
|
||||
if field.required:
|
||||
return
|
||||
else:
|
||||
return deepcopy(field.default)
|
||||
return value
|
||||
|
||||
|
||||
def request_params_to_args(
|
||||
required_params: Sequence[ModelField],
|
||||
fields: Sequence[ModelField],
|
||||
received_params: Union[Mapping[str, Any], QueryParams, Headers],
|
||||
) -> Tuple[Dict[str, Any], List[Any]]:
|
||||
values = {}
|
||||
errors = []
|
||||
for field in required_params:
|
||||
if is_scalar_sequence_field(field) and isinstance(
|
||||
received_params, (QueryParams, Headers)
|
||||
):
|
||||
value = received_params.getlist(field.alias) or field.default
|
||||
else:
|
||||
value = received_params.get(field.alias)
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Dict[str, Any]] = []
|
||||
|
||||
if not fields:
|
||||
return values, errors
|
||||
|
||||
first_field = fields[0]
|
||||
fields_to_extract = fields
|
||||
single_not_embedded_field = False
|
||||
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
single_not_embedded_field = True
|
||||
|
||||
params_to_process: Dict[str, Any] = {}
|
||||
|
||||
processed_keys = set()
|
||||
|
||||
for field in fields_to_extract:
|
||||
alias = None
|
||||
if isinstance(received_params, Headers):
|
||||
# Handle fields extracted from a Pydantic Model for a header, each field
|
||||
# doesn't have a FieldInfo of type Header with the default convert_underscores=True
|
||||
convert_underscores = getattr(field.field_info, "convert_underscores", True)
|
||||
if convert_underscores:
|
||||
alias = (
|
||||
field.alias
|
||||
if field.alias != field.name
|
||||
else field.name.replace("_", "-")
|
||||
)
|
||||
value = _get_multidict_value(field, received_params, alias=alias)
|
||||
if value is not None:
|
||||
params_to_process[field.name] = value
|
||||
processed_keys.add(alias or field.alias)
|
||||
processed_keys.add(field.name)
|
||||
|
||||
for key, value in received_params.items():
|
||||
if key not in processed_keys:
|
||||
params_to_process[key] = value
|
||||
|
||||
if single_not_embedded_field:
|
||||
field_info = first_field.field_info
|
||||
assert isinstance(
|
||||
field_info, params.Param
|
||||
), "Params must be subclasses of Param"
|
||||
loc: Tuple[str, ...] = (field_info.in_.value,)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=first_field, value=params_to_process, values=values, loc=loc
|
||||
)
|
||||
return {first_field.name: v_}, errors_
|
||||
|
||||
for field in fields:
|
||||
value = _get_multidict_value(field, received_params)
|
||||
field_info = field.field_info
|
||||
assert isinstance(
|
||||
field_info, params.Param
|
||||
), "Params must be subclasses of Param"
|
||||
loc = (field_info.in_.value, field.alias)
|
||||
if value is None:
|
||||
if field.required:
|
||||
errors.append(get_missing_field_error(loc=loc))
|
||||
else:
|
||||
values[field.name] = deepcopy(field.default)
|
||||
continue
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
errors.append(errors_)
|
||||
elif isinstance(errors_, list):
|
||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
errors.extend(new_errors)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
if errors_:
|
||||
errors.extend(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
return values, errors
|
||||
|
||||
|
||||
async def request_body_to_args(
|
||||
required_params: List[ModelField],
|
||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
def _should_embed_body_fields(fields: List[ModelField]) -> bool:
|
||||
if not fields:
|
||||
return False
|
||||
# More than one dependency could have the same field, it would show up as multiple
|
||||
# fields but it's the same one, so count them by name
|
||||
body_param_names_set = {field.name for field in fields}
|
||||
# A top level field has to be a single field, not multiple
|
||||
if len(body_param_names_set) > 1:
|
||||
return True
|
||||
first_field = fields[0]
|
||||
# If it explicitly specifies it is embedded, it has to be embedded
|
||||
if getattr(first_field.field_info, "embed", None):
|
||||
return True
|
||||
# If it's a Form (or File) field, it has to be a BaseModel to be top level
|
||||
# otherwise it has to be embedded, so that the key value pair can be extracted
|
||||
if isinstance(first_field.field_info, params.Form) and not lenient_issubclass(
|
||||
first_field.type_, BaseModel
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _extract_form_body(
|
||||
body_fields: List[ModelField],
|
||||
received_body: FormData,
|
||||
) -> Dict[str, Any]:
|
||||
values = {}
|
||||
first_field = body_fields[0]
|
||||
first_field_info = first_field.field_info
|
||||
|
||||
for field in body_fields:
|
||||
value = _get_multidict_value(field, received_body)
|
||||
if (
|
||||
isinstance(first_field_info, params.File)
|
||||
and is_bytes_field(field)
|
||||
and isinstance(value, UploadFile)
|
||||
):
|
||||
value = await value.read()
|
||||
elif (
|
||||
is_bytes_sequence_field(field)
|
||||
and isinstance(first_field_info, params.File)
|
||||
and value_is_sequence(value)
|
||||
):
|
||||
# For types
|
||||
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
||||
results: List[Union[bytes, str]] = []
|
||||
|
||||
async def process_fn(
|
||||
fn: Callable[[], Coroutine[Any, Any, Any]],
|
||||
) -> None:
|
||||
result = await fn()
|
||||
results.append(result) # noqa: B023
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for sub_value in value:
|
||||
tg.start_soon(process_fn, sub_value.read)
|
||||
value = serialize_sequence_value(field=field, value=results)
|
||||
if value is not None:
|
||||
values[field.alias] = value
|
||||
for key, value in received_body.items():
|
||||
if key not in values:
|
||||
values[key] = value
|
||||
return values
|
||||
|
||||
|
||||
async def request_body_to_args(
|
||||
body_fields: List[ModelField],
|
||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||
embed_body_fields: bool,
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Dict[str, Any]] = []
|
||||
if required_params:
|
||||
field = required_params[0]
|
||||
field_info = field.field_info
|
||||
embed = getattr(field_info, "embed", None)
|
||||
field_alias_omitted = len(required_params) == 1 and not embed
|
||||
if field_alias_omitted:
|
||||
received_body = {field.alias: received_body}
|
||||
assert body_fields, "request_body_to_args() should be called with fields"
|
||||
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
|
||||
first_field = body_fields[0]
|
||||
body_to_process = received_body
|
||||
|
||||
for field in required_params:
|
||||
loc: Tuple[str, ...]
|
||||
if field_alias_omitted:
|
||||
loc = ("body",)
|
||||
else:
|
||||
loc = ("body", field.alias)
|
||||
fields_to_extract: List[ModelField] = body_fields
|
||||
|
||||
value: Optional[Any] = None
|
||||
if received_body is not None:
|
||||
if (is_sequence_field(field)) and isinstance(received_body, FormData):
|
||||
value = received_body.getlist(field.alias)
|
||||
else:
|
||||
try:
|
||||
value = received_body.get(field.alias)
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
continue
|
||||
if (
|
||||
value is None
|
||||
or (isinstance(field_info, params.Form) and value == "")
|
||||
or (
|
||||
isinstance(field_info, params.Form)
|
||||
and is_sequence_field(field)
|
||||
and len(value) == 0
|
||||
)
|
||||
):
|
||||
if field.required:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
else:
|
||||
values[field.name] = deepcopy(field.default)
|
||||
if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
|
||||
if isinstance(received_body, FormData):
|
||||
body_to_process = await _extract_form_body(fields_to_extract, received_body)
|
||||
|
||||
if single_not_embedded_field:
|
||||
loc: Tuple[str, ...] = ("body",)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=first_field, value=body_to_process, values=values, loc=loc
|
||||
)
|
||||
return {first_field.name: v_}, errors_
|
||||
for field in body_fields:
|
||||
loc = ("body", field.alias)
|
||||
value: Optional[Any] = None
|
||||
if body_to_process is not None:
|
||||
try:
|
||||
value = body_to_process.get(field.alias)
|
||||
# If the received body is a list, not a dict
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
continue
|
||||
if (
|
||||
isinstance(field_info, params.File)
|
||||
and is_bytes_field(field)
|
||||
and isinstance(value, UploadFile)
|
||||
):
|
||||
value = await value.read()
|
||||
elif (
|
||||
is_bytes_sequence_field(field)
|
||||
and isinstance(field_info, params.File)
|
||||
and value_is_sequence(value)
|
||||
):
|
||||
# For types
|
||||
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
||||
results: List[Union[bytes, str]] = []
|
||||
|
||||
async def process_fn(
|
||||
fn: Callable[[], Coroutine[Any, Any, Any]]
|
||||
) -> None:
|
||||
result = await fn()
|
||||
results.append(result) # noqa: B023
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for sub_value in value:
|
||||
tg.start_soon(process_fn, sub_value.read)
|
||||
value = serialize_sequence_value(field=field, value=results)
|
||||
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
|
||||
if isinstance(errors_, list):
|
||||
errors.extend(errors_)
|
||||
elif errors_:
|
||||
errors.append(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
if errors_:
|
||||
errors.extend(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
return values, errors
|
||||
|
||||
|
||||
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant)
|
||||
def get_body_field(
|
||||
*, flat_dependant: Dependant, name: str, embed_body_fields: bool
|
||||
) -> Optional[ModelField]:
|
||||
"""
|
||||
Get a ModelField representing the request body for a path operation, combining
|
||||
all body parameters into a single field if necessary.
|
||||
|
||||
Used to check if it's form data (with `isinstance(body_field, params.Form)`)
|
||||
or JSON and to generate the JSON Schema for a request body.
|
||||
|
||||
This is **not** used to validate/parse the request body, that's done with each
|
||||
individual body parameter.
|
||||
"""
|
||||
if not flat_dependant.body_params:
|
||||
return None
|
||||
first_param = flat_dependant.body_params[0]
|
||||
field_info = first_param.field_info
|
||||
embed = getattr(field_info, "embed", None)
|
||||
body_param_names_set = {param.name for param in flat_dependant.body_params}
|
||||
if len(body_param_names_set) == 1 and not embed:
|
||||
check_file_field(first_param)
|
||||
if not embed_body_fields:
|
||||
return first_param
|
||||
# If one field requires to embed, all have to be embedded
|
||||
# in case a sub-dependency is evaluated with a single unique body field
|
||||
# That is combined (embedded) with other body fields
|
||||
for param in flat_dependant.body_params:
|
||||
setattr(param.field_info, "embed", True) # noqa: B010
|
||||
model_name = "Body_" + name
|
||||
BodyModel = create_body_model(
|
||||
fields=flat_dependant.body_params, model_name=model_name
|
||||
@@ -805,12 +962,11 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
|
||||
]
|
||||
if len(set(body_param_media_types)) == 1:
|
||||
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
|
||||
final_field = create_response_field(
|
||||
final_field = create_model_field(
|
||||
name="body",
|
||||
type_=BodyModel,
|
||||
required=required,
|
||||
alias="body",
|
||||
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
|
||||
)
|
||||
check_file_field(final_field)
|
||||
return final_field
|
||||
|
@@ -22,9 +22,9 @@ from pydantic import BaseModel
|
||||
from pydantic.color import Color
|
||||
from pydantic.networks import AnyUrl, NameEmail
|
||||
from pydantic.types import SecretBytes, SecretStr
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
from ._compat import PYDANTIC_V2, Url, _model_dump
|
||||
from ._compat import PYDANTIC_V2, UndefinedType, Url, _model_dump
|
||||
|
||||
|
||||
# Taken from Pydantic v1 as is
|
||||
@@ -86,7 +86,7 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
|
||||
|
||||
|
||||
def generate_encoders_by_class_tuples(
|
||||
type_encoder_map: Dict[Any, Callable[[Any], Any]]
|
||||
type_encoder_map: Dict[Any, Callable[[Any], Any]],
|
||||
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
|
||||
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
|
||||
tuple
|
||||
@@ -259,6 +259,8 @@ def jsonable_encoder(
|
||||
return str(obj)
|
||||
if isinstance(obj, (str, int, float, type(None))):
|
||||
return obj
|
||||
if isinstance(obj, UndefinedType):
|
||||
return None
|
||||
if isinstance(obj, dict):
|
||||
encoded_dict = {}
|
||||
allowed_keys = set(obj.keys())
|
||||
|
@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence, Type, Union
|
||||
from pydantic import BaseModel, create_model
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.exceptions import WebSocketException as StarletteWebSocketException
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
|
||||
class HTTPException(StarletteHTTPException):
|
||||
|
@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from starlette.responses import HTMLResponse
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
swagger_ui_default_parameters: Annotated[
|
||||
Dict[str, Any],
|
||||
@@ -53,7 +53,7 @@ def get_swagger_ui_html(
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js",
|
||||
swagger_css_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
@@ -63,7 +63,7 @@ def get_swagger_ui_html(
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui.css",
|
||||
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css",
|
||||
swagger_favicon_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
|
@@ -55,35 +55,29 @@ except ImportError: # pragma: no cover
|
||||
return with_info_plain_validator_function(cls._validate)
|
||||
|
||||
|
||||
class Contact(BaseModel):
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Contact(BaseModelWithConfig):
|
||||
name: Optional[str] = None
|
||||
url: Optional[AnyUrl] = None
|
||||
email: Optional[EmailStr] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class License(BaseModel):
|
||||
class License(BaseModelWithConfig):
|
||||
name: str
|
||||
identifier: Optional[str] = None
|
||||
url: Optional[AnyUrl] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Info(BaseModel):
|
||||
class Info(BaseModelWithConfig):
|
||||
title: str
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -92,42 +86,18 @@ class Info(BaseModel):
|
||||
license: Optional[License] = None
|
||||
version: str
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class ServerVariable(BaseModel):
|
||||
class ServerVariable(BaseModelWithConfig):
|
||||
enum: Annotated[Optional[List[str]], Field(min_length=1)] = None
|
||||
default: str
|
||||
description: Optional[str] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Server(BaseModel):
|
||||
class Server(BaseModelWithConfig):
|
||||
url: Union[AnyUrl, str]
|
||||
description: Optional[str] = None
|
||||
variables: Optional[Dict[str, ServerVariable]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Reference(BaseModel):
|
||||
ref: str = Field(alias="$ref")
|
||||
@@ -138,36 +108,20 @@ class Discriminator(BaseModel):
|
||||
mapping: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
class XML(BaseModel):
|
||||
class XML(BaseModelWithConfig):
|
||||
name: Optional[str] = None
|
||||
namespace: Optional[str] = None
|
||||
prefix: Optional[str] = None
|
||||
attribute: Optional[bool] = None
|
||||
wrapped: Optional[bool] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class ExternalDocumentation(BaseModel):
|
||||
class ExternalDocumentation(BaseModelWithConfig):
|
||||
description: Optional[str] = None
|
||||
url: AnyUrl
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Schema(BaseModel):
|
||||
class Schema(BaseModelWithConfig):
|
||||
# Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-the-json-schema-core-vocabu
|
||||
# Core Vocabulary
|
||||
schema_: Optional[str] = Field(default=None, alias="$schema")
|
||||
@@ -253,14 +207,6 @@ class Schema(BaseModel):
|
||||
),
|
||||
] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
# Ref: https://json-schema.org/draft/2020-12/json-schema-core.html#name-json-schema-documents
|
||||
# A JSON Schema MUST be an object or a boolean.
|
||||
@@ -289,38 +235,22 @@ class ParameterInType(Enum):
|
||||
cookie = "cookie"
|
||||
|
||||
|
||||
class Encoding(BaseModel):
|
||||
class Encoding(BaseModelWithConfig):
|
||||
contentType: Optional[str] = None
|
||||
headers: Optional[Dict[str, Union["Header", Reference]]] = None
|
||||
style: Optional[str] = None
|
||||
explode: Optional[bool] = None
|
||||
allowReserved: Optional[bool] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class MediaType(BaseModel):
|
||||
class MediaType(BaseModelWithConfig):
|
||||
schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema")
|
||||
example: Optional[Any] = None
|
||||
examples: Optional[Dict[str, Union[Example, Reference]]] = None
|
||||
encoding: Optional[Dict[str, Encoding]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class ParameterBase(BaseModel):
|
||||
class ParameterBase(BaseModelWithConfig):
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = None
|
||||
deprecated: Optional[bool] = None
|
||||
@@ -334,14 +264,6 @@ class ParameterBase(BaseModel):
|
||||
# Serialization rules for more complex scenarios
|
||||
content: Optional[Dict[str, MediaType]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Parameter(ParameterBase):
|
||||
name: str
|
||||
@@ -352,21 +274,13 @@ class Header(ParameterBase):
|
||||
pass
|
||||
|
||||
|
||||
class RequestBody(BaseModel):
|
||||
class RequestBody(BaseModelWithConfig):
|
||||
description: Optional[str] = None
|
||||
content: Dict[str, MediaType]
|
||||
required: Optional[bool] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Link(BaseModel):
|
||||
class Link(BaseModelWithConfig):
|
||||
operationRef: Optional[str] = None
|
||||
operationId: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Union[Any, str]]] = None
|
||||
@@ -374,31 +288,15 @@ class Link(BaseModel):
|
||||
description: Optional[str] = None
|
||||
server: Optional[Server] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
class Response(BaseModelWithConfig):
|
||||
description: str
|
||||
headers: Optional[Dict[str, Union[Header, Reference]]] = None
|
||||
content: Optional[Dict[str, MediaType]] = None
|
||||
links: Optional[Dict[str, Union[Link, Reference]]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Operation(BaseModel):
|
||||
class Operation(BaseModelWithConfig):
|
||||
tags: Optional[List[str]] = None
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -413,16 +311,8 @@ class Operation(BaseModel):
|
||||
security: Optional[List[Dict[str, List[str]]]] = None
|
||||
servers: Optional[List[Server]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class PathItem(BaseModel):
|
||||
class PathItem(BaseModelWithConfig):
|
||||
ref: Optional[str] = Field(default=None, alias="$ref")
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -437,14 +327,6 @@ class PathItem(BaseModel):
|
||||
servers: Optional[List[Server]] = None
|
||||
parameters: Optional[List[Union[Parameter, Reference]]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class SecuritySchemeType(Enum):
|
||||
apiKey = "apiKey"
|
||||
@@ -453,18 +335,10 @@ class SecuritySchemeType(Enum):
|
||||
openIdConnect = "openIdConnect"
|
||||
|
||||
|
||||
class SecurityBase(BaseModel):
|
||||
class SecurityBase(BaseModelWithConfig):
|
||||
type_: SecuritySchemeType = Field(alias="type")
|
||||
description: Optional[str] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class APIKeyIn(Enum):
|
||||
query = "query"
|
||||
@@ -488,18 +362,10 @@ class HTTPBearer(HTTPBase):
|
||||
bearerFormat: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthFlow(BaseModel):
|
||||
class OAuthFlow(BaseModelWithConfig):
|
||||
refreshUrl: Optional[str] = None
|
||||
scopes: Dict[str, str] = {}
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class OAuthFlowImplicit(OAuthFlow):
|
||||
authorizationUrl: str
|
||||
@@ -518,20 +384,12 @@ class OAuthFlowAuthorizationCode(OAuthFlow):
|
||||
tokenUrl: str
|
||||
|
||||
|
||||
class OAuthFlows(BaseModel):
|
||||
class OAuthFlows(BaseModelWithConfig):
|
||||
implicit: Optional[OAuthFlowImplicit] = None
|
||||
password: Optional[OAuthFlowPassword] = None
|
||||
clientCredentials: Optional[OAuthFlowClientCredentials] = None
|
||||
authorizationCode: Optional[OAuthFlowAuthorizationCode] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class OAuth2(SecurityBase):
|
||||
type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type")
|
||||
@@ -548,7 +406,7 @@ class OpenIdConnect(SecurityBase):
|
||||
SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer]
|
||||
|
||||
|
||||
class Components(BaseModel):
|
||||
class Components(BaseModelWithConfig):
|
||||
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
|
||||
responses: Optional[Dict[str, Union[Response, Reference]]] = None
|
||||
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
|
||||
@@ -561,30 +419,14 @@ class Components(BaseModel):
|
||||
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference, Any]]] = None
|
||||
pathItems: Optional[Dict[str, Union[PathItem, Reference]]] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Tag(BaseModel):
|
||||
class Tag(BaseModelWithConfig):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
externalDocs: Optional[ExternalDocumentation] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class OpenAPI(BaseModel):
|
||||
class OpenAPI(BaseModelWithConfig):
|
||||
openapi: str
|
||||
info: Info
|
||||
jsonSchemaDialect: Optional[str] = None
|
||||
@@ -597,14 +439,6 @@ class OpenAPI(BaseModel):
|
||||
tags: Optional[List[Tag]] = None
|
||||
externalDocs: Optional[ExternalDocumentation] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
else:
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
_model_rebuild(Schema)
|
||||
_model_rebuild(Operation)
|
||||
|
@@ -16,11 +16,15 @@ from fastapi._compat import (
|
||||
)
|
||||
from fastapi.datastructures import DefaultPlaceholder
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
|
||||
from fastapi.dependencies.utils import (
|
||||
_get_flat_fields_from_params,
|
||||
get_flat_dependant,
|
||||
get_flat_params,
|
||||
)
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
|
||||
from fastapi.openapi.models import OpenAPI
|
||||
from fastapi.params import Body, Param
|
||||
from fastapi.params import Body, ParamTypes
|
||||
from fastapi.responses import Response
|
||||
from fastapi.types import ModelNameMap
|
||||
from fastapi.utils import (
|
||||
@@ -87,9 +91,9 @@ def get_openapi_security_definitions(
|
||||
return security_definitions, operation_security
|
||||
|
||||
|
||||
def get_openapi_operation_parameters(
|
||||
def _get_openapi_operation_parameters(
|
||||
*,
|
||||
all_route_params: Sequence[ModelField],
|
||||
dependant: Dependant,
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
@@ -98,33 +102,47 @@ def get_openapi_operation_parameters(
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> List[Dict[str, Any]]:
|
||||
parameters = []
|
||||
for param in all_route_params:
|
||||
field_info = param.field_info
|
||||
field_info = cast(Param, field_info)
|
||||
if not field_info.include_in_schema:
|
||||
continue
|
||||
param_schema = get_schema_from_model_field(
|
||||
field=param,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
parameter = {
|
||||
"name": param.alias,
|
||||
"in": field_info.in_.value,
|
||||
"required": param.required,
|
||||
"schema": param_schema,
|
||||
}
|
||||
if field_info.description:
|
||||
parameter["description"] = field_info.description
|
||||
if field_info.openapi_examples:
|
||||
parameter["examples"] = jsonable_encoder(field_info.openapi_examples)
|
||||
elif field_info.example != Undefined:
|
||||
parameter["example"] = jsonable_encoder(field_info.example)
|
||||
if field_info.deprecated:
|
||||
parameter["deprecated"] = field_info.deprecated
|
||||
parameters.append(parameter)
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
|
||||
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
|
||||
header_params = _get_flat_fields_from_params(flat_dependant.header_params)
|
||||
cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
|
||||
parameter_groups = [
|
||||
(ParamTypes.path, path_params),
|
||||
(ParamTypes.query, query_params),
|
||||
(ParamTypes.header, header_params),
|
||||
(ParamTypes.cookie, cookie_params),
|
||||
]
|
||||
for param_type, param_group in parameter_groups:
|
||||
for param in param_group:
|
||||
field_info = param.field_info
|
||||
# field_info = cast(Param, field_info)
|
||||
if not getattr(field_info, "include_in_schema", True):
|
||||
continue
|
||||
param_schema = get_schema_from_model_field(
|
||||
field=param,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
parameter = {
|
||||
"name": param.alias,
|
||||
"in": param_type.value,
|
||||
"required": param.required,
|
||||
"schema": param_schema,
|
||||
}
|
||||
if field_info.description:
|
||||
parameter["description"] = field_info.description
|
||||
openapi_examples = getattr(field_info, "openapi_examples", None)
|
||||
example = getattr(field_info, "example", None)
|
||||
if openapi_examples:
|
||||
parameter["examples"] = jsonable_encoder(openapi_examples)
|
||||
elif example != Undefined:
|
||||
parameter["example"] = jsonable_encoder(example)
|
||||
if getattr(field_info, "deprecated", None):
|
||||
parameter["deprecated"] = True
|
||||
parameters.append(parameter)
|
||||
return parameters
|
||||
|
||||
|
||||
@@ -247,9 +265,8 @@ def get_openapi_path(
|
||||
operation.setdefault("security", []).extend(operation_security)
|
||||
if security_definitions:
|
||||
security_schemes.update(security_definitions)
|
||||
all_route_params = get_flat_params(route.dependant)
|
||||
operation_parameters = get_openapi_operation_parameters(
|
||||
all_route_params=all_route_params,
|
||||
operation_parameters = _get_openapi_operation_parameters(
|
||||
dependant=route.dependant,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
@@ -379,6 +396,7 @@ def get_openapi_path(
|
||||
deep_dict_update(openapi_response, process_response)
|
||||
openapi_response["description"] = description
|
||||
http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
|
||||
all_route_params = get_flat_params(route.dependant)
|
||||
if (all_route_params or route.body_field) and not any(
|
||||
status in operation["responses"]
|
||||
for status in [http422, "4XX", "default"]
|
||||
|
@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
from fastapi import params
|
||||
from fastapi._compat import Undefined
|
||||
from fastapi.openapi.models import Example
|
||||
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc, deprecated
|
||||
|
||||
_Unset: Any = Undefined
|
||||
|
||||
@@ -240,7 +240,7 @@ def Path( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -565,7 +565,7 @@ def Query( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -880,7 +880,7 @@ def Header( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -1185,7 +1185,7 @@ def Cookie( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -1282,7 +1282,7 @@ def Body( # noqa: N802
|
||||
),
|
||||
] = _Unset,
|
||||
embed: Annotated[
|
||||
bool,
|
||||
Union[bool, None],
|
||||
Doc(
|
||||
"""
|
||||
When `embed` is `True`, the parameter will be expected in a JSON body as a
|
||||
@@ -1294,7 +1294,7 @@ def Body( # noqa: N802
|
||||
[FastAPI docs for Body - Multiple Parameters](https://fastapi.tiangolo.com/tutorial/body-multiple-params/#embed-a-single-body-parameter).
|
||||
"""
|
||||
),
|
||||
] = False,
|
||||
] = None,
|
||||
media_type: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
@@ -1512,7 +1512,7 @@ def Body( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -1827,7 +1827,7 @@ def Form( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -2141,7 +2141,7 @@ def File( # noqa: N802
|
||||
),
|
||||
] = None,
|
||||
deprecated: Annotated[
|
||||
Optional[bool],
|
||||
Union[deprecated, str, bool, None],
|
||||
Doc(
|
||||
"""
|
||||
Mark this parameter field as deprecated.
|
||||
@@ -2298,7 +2298,7 @@ def Security( # noqa: N802
|
||||
dependency.
|
||||
|
||||
The term "scope" comes from the OAuth2 specification, it seems to be
|
||||
intentionaly vague and interpretable. It normally refers to permissions,
|
||||
intentionally vague and interpretable. It normally refers to permissions,
|
||||
in cases to roles.
|
||||
|
||||
These scopes are integrated with OpenAPI (and the API docs at `/docs`).
|
||||
@@ -2343,7 +2343,7 @@ def Security( # noqa: N802
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi import Security, FastAPI
|
||||
|
||||
from .db import User
|
||||
from .security import get_current_active_user
|
||||
|
@@ -6,7 +6,11 @@ from fastapi.openapi.models import Example
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Annotated, deprecated
|
||||
|
||||
from ._compat import PYDANTIC_V2, Undefined
|
||||
from ._compat import (
|
||||
PYDANTIC_V2,
|
||||
PYDANTIC_VERSION_MINOR_TUPLE,
|
||||
Undefined,
|
||||
)
|
||||
|
||||
_Unset: Any = Undefined
|
||||
|
||||
@@ -63,12 +67,11 @@ class Param(FieldInfo):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
self.deprecated = deprecated
|
||||
if example is not _Unset:
|
||||
warnings.warn(
|
||||
"`example` has been deprecated, please use `examples` instead",
|
||||
@@ -92,7 +95,7 @@ class Param(FieldInfo):
|
||||
max_length=max_length,
|
||||
discriminator=discriminator,
|
||||
multiple_of=multiple_of,
|
||||
allow_nan=allow_inf_nan,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
**extra,
|
||||
@@ -106,6 +109,10 @@ class Param(FieldInfo):
|
||||
stacklevel=4,
|
||||
)
|
||||
current_json_schema_extra = json_schema_extra or extra
|
||||
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
|
||||
self.deprecated = deprecated
|
||||
else:
|
||||
kwargs["deprecated"] = deprecated
|
||||
if PYDANTIC_V2:
|
||||
kwargs.update(
|
||||
{
|
||||
@@ -174,7 +181,7 @@ class Path(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -260,7 +267,7 @@ class Query(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -345,7 +352,7 @@ class Header(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -430,7 +437,7 @@ class Cookie(Param):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -476,7 +483,7 @@ class Body(FieldInfo):
|
||||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
embed: bool = False,
|
||||
embed: Union[bool, None] = None,
|
||||
media_type: str = "application/json",
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
@@ -514,14 +521,13 @@ class Body(FieldInfo):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
):
|
||||
self.embed = embed
|
||||
self.media_type = media_type
|
||||
self.deprecated = deprecated
|
||||
if example is not _Unset:
|
||||
warnings.warn(
|
||||
"`example` has been deprecated, please use `examples` instead",
|
||||
@@ -545,7 +551,7 @@ class Body(FieldInfo):
|
||||
max_length=max_length,
|
||||
discriminator=discriminator,
|
||||
multiple_of=multiple_of,
|
||||
allow_nan=allow_inf_nan,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
**extra,
|
||||
@@ -554,11 +560,15 @@ class Body(FieldInfo):
|
||||
kwargs["examples"] = examples
|
||||
if regex is not None:
|
||||
warnings.warn(
|
||||
"`regex` has been depreacated, please use `pattern` instead",
|
||||
"`regex` has been deprecated, please use `pattern` instead",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
current_json_schema_extra = json_schema_extra or extra
|
||||
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
|
||||
self.deprecated = deprecated
|
||||
else:
|
||||
kwargs["deprecated"] = deprecated
|
||||
if PYDANTIC_V2:
|
||||
kwargs.update(
|
||||
{
|
||||
@@ -627,7 +637,7 @@ class Form(Body):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
@@ -636,7 +646,6 @@ class Form(Body):
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
annotation=annotation,
|
||||
embed=True,
|
||||
media_type=media_type,
|
||||
alias=alias,
|
||||
alias_priority=alias_priority,
|
||||
@@ -712,7 +721,7 @@ class File(Form):
|
||||
),
|
||||
] = _Unset,
|
||||
openapi_examples: Optional[Dict[str, Example]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
deprecated: Union[deprecated, str, bool, None] = None,
|
||||
include_in_schema: bool = True,
|
||||
json_schema_extra: Union[Dict[str, Any], None] = None,
|
||||
**extra: Any,
|
||||
|
@@ -3,14 +3,16 @@ import dataclasses
|
||||
import email.message
|
||||
import inspect
|
||||
import json
|
||||
from contextlib import AsyncExitStack
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from enum import Enum, IntEnum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
@@ -31,8 +33,10 @@ from fastapi._compat import (
|
||||
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.dependencies.utils import (
|
||||
_should_embed_body_fields,
|
||||
get_body_field,
|
||||
get_dependant,
|
||||
get_flat_dependant,
|
||||
get_parameterless_sub_dependant,
|
||||
get_typed_return_annotation,
|
||||
solve_dependencies,
|
||||
@@ -47,7 +51,7 @@ from fastapi.exceptions import (
|
||||
from fastapi.types import DecoratedCallable, IncEx
|
||||
from fastapi.utils import (
|
||||
create_cloned_field,
|
||||
create_response_field,
|
||||
create_model_field,
|
||||
generate_unique_id,
|
||||
get_value_or_default,
|
||||
is_body_allowed_for_status_code,
|
||||
@@ -67,9 +71,9 @@ from starlette.routing import (
|
||||
websocket_session,
|
||||
)
|
||||
from starlette.routing import Mount as Mount # noqa
|
||||
from starlette.types import ASGIApp, Lifespan, Scope
|
||||
from starlette.types import AppType, ASGIApp, Lifespan, Scope
|
||||
from starlette.websockets import WebSocket
|
||||
from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc, deprecated
|
||||
|
||||
|
||||
def _prepare_response_content(
|
||||
@@ -119,6 +123,23 @@ def _prepare_response_content(
|
||||
return res
|
||||
|
||||
|
||||
def _merge_lifespan_context(
|
||||
original_context: Lifespan[Any], nested_context: Lifespan[Any]
|
||||
) -> Lifespan[Any]:
|
||||
@asynccontextmanager
|
||||
async def merged_lifespan(
|
||||
app: AppType,
|
||||
) -> AsyncIterator[Optional[Mapping[str, Any]]]:
|
||||
async with original_context(app) as maybe_original_state:
|
||||
async with nested_context(app) as maybe_nested_state:
|
||||
if maybe_nested_state is None and maybe_original_state is None:
|
||||
yield None # old ASGI compatibility
|
||||
else:
|
||||
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
|
||||
|
||||
return merged_lifespan # type: ignore[return-value]
|
||||
|
||||
|
||||
async def serialize_response(
|
||||
*,
|
||||
field: Optional[ModelField] = None,
|
||||
@@ -206,6 +227,7 @@ def get_request_handler(
|
||||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
embed_body_fields: bool = False,
|
||||
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
||||
@@ -272,27 +294,36 @@ def get_request_handler(
|
||||
body=body,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
values, errors, background_tasks, sub_response, _ = solved_result
|
||||
errors = solved_result.errors
|
||||
if not errors:
|
||||
raw_response = await run_endpoint_function(
|
||||
dependant=dependant, values=values, is_coroutine=is_coroutine
|
||||
dependant=dependant,
|
||||
values=solved_result.values,
|
||||
is_coroutine=is_coroutine,
|
||||
)
|
||||
if isinstance(raw_response, Response):
|
||||
if raw_response.background is None:
|
||||
raw_response.background = background_tasks
|
||||
raw_response.background = solved_result.background_tasks
|
||||
response = raw_response
|
||||
else:
|
||||
response_args: Dict[str, Any] = {"background": background_tasks}
|
||||
response_args: Dict[str, Any] = {
|
||||
"background": solved_result.background_tasks
|
||||
}
|
||||
# If status_code was set, use it, otherwise use the default from the
|
||||
# response class, in the case of redirect it's 307
|
||||
current_status_code = (
|
||||
status_code if status_code else sub_response.status_code
|
||||
status_code
|
||||
if status_code
|
||||
else solved_result.response.status_code
|
||||
)
|
||||
if current_status_code is not None:
|
||||
response_args["status_code"] = current_status_code
|
||||
if sub_response.status_code:
|
||||
response_args["status_code"] = sub_response.status_code
|
||||
if solved_result.response.status_code:
|
||||
response_args["status_code"] = (
|
||||
solved_result.response.status_code
|
||||
)
|
||||
content = await serialize_response(
|
||||
field=response_field,
|
||||
response_content=raw_response,
|
||||
@@ -307,7 +338,7 @@ def get_request_handler(
|
||||
response = actual_response_class(content, **response_args)
|
||||
if not is_body_allowed_for_status_code(response.status_code):
|
||||
response.body = b""
|
||||
response.headers.raw.extend(sub_response.headers.raw)
|
||||
response.headers.raw.extend(solved_result.response.headers.raw)
|
||||
if errors:
|
||||
validation_error = RequestValidationError(
|
||||
_normalize_errors(errors), body=body
|
||||
@@ -327,7 +358,9 @@ def get_request_handler(
|
||||
|
||||
|
||||
def get_websocket_app(
|
||||
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
|
||||
dependant: Dependant,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
embed_body_fields: bool = False,
|
||||
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
|
||||
async def app(websocket: WebSocket) -> None:
|
||||
async with AsyncExitStack() as async_exit_stack:
|
||||
@@ -340,12 +373,14 @@ def get_websocket_app(
|
||||
dependant=dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
values, errors, _, _2, _3 = solved_result
|
||||
if errors:
|
||||
raise WebSocketRequestValidationError(_normalize_errors(errors))
|
||||
if solved_result.errors:
|
||||
raise WebSocketRequestValidationError(
|
||||
_normalize_errors(solved_result.errors)
|
||||
)
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
await dependant.call(**values)
|
||||
await dependant.call(**solved_result.values)
|
||||
|
||||
return app
|
||||
|
||||
@@ -371,11 +406,15 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
)
|
||||
|
||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||
self._embed_body_fields = _should_embed_body_fields(
|
||||
self._flat_dependant.body_params
|
||||
)
|
||||
self.app = websocket_session(
|
||||
get_websocket_app(
|
||||
dependant=self.dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -454,9 +493,9 @@ class APIRoute(routing.Route):
|
||||
methods = ["GET"]
|
||||
self.methods: Set[str] = {method.upper() for method in methods}
|
||||
if isinstance(generate_unique_id_function, DefaultPlaceholder):
|
||||
current_generate_unique_id: Callable[
|
||||
["APIRoute"], str
|
||||
] = generate_unique_id_function.value
|
||||
current_generate_unique_id: Callable[[APIRoute], str] = (
|
||||
generate_unique_id_function.value
|
||||
)
|
||||
else:
|
||||
current_generate_unique_id = generate_unique_id_function
|
||||
self.unique_id = self.operation_id or current_generate_unique_id(self)
|
||||
@@ -469,7 +508,7 @@ class APIRoute(routing.Route):
|
||||
status_code
|
||||
), f"Status code {status_code} must not have a response body"
|
||||
response_name = "Response_" + self.unique_id
|
||||
self.response_field = create_response_field(
|
||||
self.response_field = create_model_field(
|
||||
name=response_name,
|
||||
type_=self.response_model,
|
||||
mode="serialization",
|
||||
@@ -482,9 +521,9 @@ class APIRoute(routing.Route):
|
||||
# By being a new field, no inheritance will be passed as is. A new model
|
||||
# will always be created.
|
||||
# TODO: remove when deprecating Pydantic v1
|
||||
self.secure_cloned_response_field: Optional[
|
||||
ModelField
|
||||
] = create_cloned_field(self.response_field)
|
||||
self.secure_cloned_response_field: Optional[ModelField] = (
|
||||
create_cloned_field(self.response_field)
|
||||
)
|
||||
else:
|
||||
self.response_field = None # type: ignore
|
||||
self.secure_cloned_response_field = None
|
||||
@@ -502,7 +541,9 @@ class APIRoute(routing.Route):
|
||||
additional_status_code
|
||||
), f"Status code {additional_status_code} must not have a response body"
|
||||
response_name = f"Response_{additional_status_code}_{self.unique_id}"
|
||||
response_field = create_response_field(name=response_name, type_=model)
|
||||
response_field = create_model_field(
|
||||
name=response_name, type_=model, mode="serialization"
|
||||
)
|
||||
response_fields[additional_status_code] = response_field
|
||||
if response_fields:
|
||||
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
|
||||
@@ -516,7 +557,15 @@ class APIRoute(routing.Route):
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
)
|
||||
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
|
||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||
self._embed_body_fields = _should_embed_body_fields(
|
||||
self._flat_dependant.body_params
|
||||
)
|
||||
self.body_field = get_body_field(
|
||||
flat_dependant=self._flat_dependant,
|
||||
name=self.unique_id,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
self.app = request_response(self.get_route_handler())
|
||||
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
@@ -533,6 +582,7 @@ class APIRoute(routing.Route):
|
||||
response_model_exclude_defaults=self.response_model_exclude_defaults,
|
||||
response_model_exclude_none=self.response_model_exclude_none,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
|
||||
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
|
||||
@@ -1308,6 +1358,10 @@ class APIRouter(routing.Router):
|
||||
self.add_event_handler("startup", handler)
|
||||
for handler in router.on_shutdown:
|
||||
self.add_event_handler("shutdown", handler)
|
||||
self.lifespan_context = _merge_lifespan_context(
|
||||
self.lifespan_context,
|
||||
router.lifespan_context,
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
|
@@ -5,11 +5,19 @@ from fastapi.security.base import SecurityBase
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
|
||||
class APIKeyBase(SecurityBase):
|
||||
pass
|
||||
@staticmethod
|
||||
def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]:
|
||||
if not api_key:
|
||||
if auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
return None
|
||||
return api_key
|
||||
|
||||
|
||||
class APIKeyQuery(APIKeyBase):
|
||||
@@ -76,7 +84,7 @@ class APIKeyQuery(APIKeyBase):
|
||||
Doc(
|
||||
"""
|
||||
By default, if the query parameter is not provided, `APIKeyQuery` will
|
||||
automatically cancel the request and sebd the client an error.
|
||||
automatically cancel the request and send the client an error.
|
||||
|
||||
If `auto_error` is set to `False`, when the query parameter is not
|
||||
available, instead of erroring out, the dependency result will be
|
||||
@@ -101,14 +109,7 @@ class APIKeyQuery(APIKeyBase):
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.query_params.get(self.model.name)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
return self.check_api_key(api_key, self.auto_error)
|
||||
|
||||
|
||||
class APIKeyHeader(APIKeyBase):
|
||||
@@ -196,14 +197,7 @@ class APIKeyHeader(APIKeyBase):
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.headers.get(self.model.name)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
return self.check_api_key(api_key, self.auto_error)
|
||||
|
||||
|
||||
class APIKeyCookie(APIKeyBase):
|
||||
@@ -291,11 +285,4 @@ class APIKeyCookie(APIKeyBase):
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.cookies.get(self.model.name)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
return self.check_api_key(api_key, self.auto_error)
|
||||
|
@@ -10,12 +10,12 @@ from fastapi.security.utils import get_authorization_scheme_param
|
||||
from pydantic import BaseModel
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
|
||||
class HTTPBasicCredentials(BaseModel):
|
||||
"""
|
||||
The HTTP Basic credendials given as the result of using `HTTPBasic` in a
|
||||
The HTTP Basic credentials given as the result of using `HTTPBasic` in a
|
||||
dependency.
|
||||
|
||||
Read more about it in the
|
||||
@@ -277,7 +277,7 @@ class HTTPBearer(HTTPBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if the HTTP Bearer token not provided (in an
|
||||
By default, if the HTTP Bearer token is not provided (in an
|
||||
`Authorization` header), `HTTPBearer` will automatically cancel the
|
||||
request and send the client an error.
|
||||
|
||||
@@ -380,7 +380,7 @@ class HTTPDigest(HTTPBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if the HTTP Digest not provided, `HTTPDigest` will
|
||||
By default, if the HTTP Digest is not provided, `HTTPDigest` will
|
||||
automatically cancel the request and send the client an error.
|
||||
|
||||
If `auto_error` is set to `False`, when the HTTP Digest is not
|
||||
|
@@ -10,7 +10,7 @@ from starlette.requests import Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
|
||||
# TODO: import from typing when deprecating Python 3.9
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
|
||||
class OAuth2PasswordRequestForm:
|
||||
@@ -52,9 +52,9 @@ class OAuth2PasswordRequestForm:
|
||||
```
|
||||
|
||||
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
|
||||
You could have custom internal logic to separate it by colon caracters (`:`) or
|
||||
You could have custom internal logic to separate it by colon characters (`:`) or
|
||||
similar, and get the two parts `items` and `read`. Many applications do that to
|
||||
group and organize permisions, you could do it as well in your application, just
|
||||
group and organize permissions, you could do it as well in your application, just
|
||||
know that that it is application specific, it's not part of the specification.
|
||||
"""
|
||||
|
||||
@@ -63,7 +63,7 @@ class OAuth2PasswordRequestForm:
|
||||
*,
|
||||
grant_type: Annotated[
|
||||
Union[str, None],
|
||||
Form(pattern="password"),
|
||||
Form(pattern="^password$"),
|
||||
Doc(
|
||||
"""
|
||||
The OAuth2 spec says it is required and MUST be the fixed string
|
||||
@@ -194,9 +194,9 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
|
||||
```
|
||||
|
||||
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
|
||||
You could have custom internal logic to separate it by colon caracters (`:`) or
|
||||
You could have custom internal logic to separate it by colon characters (`:`) or
|
||||
similar, and get the two parts `items` and `read`. Many applications do that to
|
||||
group and organize permisions, you could do it as well in your application, just
|
||||
group and organize permissions, you could do it as well in your application, just
|
||||
know that that it is application specific, it's not part of the specification.
|
||||
|
||||
|
||||
@@ -217,7 +217,7 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
|
||||
self,
|
||||
grant_type: Annotated[
|
||||
str,
|
||||
Form(pattern="password"),
|
||||
Form(pattern="^password$"),
|
||||
Doc(
|
||||
"""
|
||||
The OAuth2 spec says it is required and MUST be the fixed string
|
||||
@@ -441,7 +441,7 @@ class OAuth2PasswordBearer(OAuth2):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
OAuth2 authentication, it will automatically cancel the request and
|
||||
send the client an error.
|
||||
|
||||
@@ -543,7 +543,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
OAuth2 authentication, it will automatically cancel the request and
|
||||
send the client an error.
|
||||
|
||||
|
@@ -5,7 +5,7 @@ from fastapi.security.base import SecurityBase
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
|
||||
class OpenIdConnect(SecurityBase):
|
||||
@@ -49,7 +49,7 @@ class OpenIdConnect(SecurityBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Auhtorization header is provided, required for
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
OpenID Connect authentication, it will automatically cancel the request
|
||||
and send the client an error.
|
||||
|
||||
|
@@ -34,9 +34,9 @@ if TYPE_CHECKING: # pragma: nocover
|
||||
from .routing import APIRoute
|
||||
|
||||
# Cache for `create_cloned_field`
|
||||
_CLONED_TYPES_CACHE: MutableMapping[
|
||||
Type[BaseModel], Type[BaseModel]
|
||||
] = WeakKeyDictionary()
|
||||
_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
|
||||
WeakKeyDictionary()
|
||||
)
|
||||
|
||||
|
||||
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
|
||||
@@ -60,9 +60,9 @@ def get_path_param_names(path: str) -> Set[str]:
|
||||
return set(re.findall("{(.*?)}", path))
|
||||
|
||||
|
||||
def create_response_field(
|
||||
def create_model_field(
|
||||
name: str,
|
||||
type_: Type[Any],
|
||||
type_: Any,
|
||||
class_validators: Optional[Dict[str, Validator]] = None,
|
||||
default: Optional[Any] = Undefined,
|
||||
required: Union[bool, UndefinedType] = Undefined,
|
||||
@@ -71,9 +71,6 @@ def create_response_field(
|
||||
alias: Optional[str] = None,
|
||||
mode: Literal["validation", "serialization"] = "validation",
|
||||
) -> ModelField:
|
||||
"""
|
||||
Create a new response field. Raises if type_ is invalid.
|
||||
"""
|
||||
class_validators = class_validators or {}
|
||||
if PYDANTIC_V2:
|
||||
field_info = field_info or FieldInfo(
|
||||
@@ -135,7 +132,7 @@ def create_cloned_field(
|
||||
use_type.__fields__[f.name] = create_cloned_field(
|
||||
f, cloned_types=cloned_types
|
||||
)
|
||||
new_field = create_response_field(name=field.name, type_=use_type)
|
||||
new_field = create_model_field(name=field.name, type_=use_type)
|
||||
new_field.has_alias = field.has_alias # type: ignore[attr-defined]
|
||||
new_field.alias = field.alias # type: ignore[misc]
|
||||
new_field.class_validators = field.class_validators # type: ignore[attr-defined]
|
||||
@@ -221,9 +218,3 @@ def get_value_or_default(
|
||||
if not isinstance(item, DefaultPlaceholder):
|
||||
return item
|
||||
return first_item
|
||||
|
||||
|
||||
def match_pydantic_error_url(error_type: str) -> Any:
|
||||
from dirty_equals import IsStr
|
||||
|
||||
return IsStr(regex=rf"^https://errors\.pydantic\.dev/.*/v/{error_type}")
|
||||
|
Reference in New Issue
Block a user