mirror of
https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git
synced 2025-08-14 00:25:46 +02:00
Проверка 09.02.2025
This commit is contained in:
@@ -1,58 +1,37 @@
|
||||
from typing import Any, Callable, List, Optional, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple
|
||||
|
||||
from fastapi._compat import ModelField
|
||||
from fastapi.security.base import SecurityBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityRequirement:
|
||||
def __init__(
|
||||
self, security_scheme: SecurityBase, scopes: Optional[Sequence[str]] = None
|
||||
):
|
||||
self.security_scheme = security_scheme
|
||||
self.scopes = scopes
|
||||
security_scheme: SecurityBase
|
||||
scopes: Optional[Sequence[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dependant:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
path_params: Optional[List[ModelField]] = None,
|
||||
query_params: Optional[List[ModelField]] = None,
|
||||
header_params: Optional[List[ModelField]] = None,
|
||||
cookie_params: Optional[List[ModelField]] = None,
|
||||
body_params: Optional[List[ModelField]] = None,
|
||||
dependencies: Optional[List["Dependant"]] = None,
|
||||
security_schemes: Optional[List[SecurityRequirement]] = None,
|
||||
name: Optional[str] = None,
|
||||
call: Optional[Callable[..., Any]] = None,
|
||||
request_param_name: Optional[str] = None,
|
||||
websocket_param_name: Optional[str] = None,
|
||||
http_connection_param_name: Optional[str] = None,
|
||||
response_param_name: Optional[str] = None,
|
||||
background_tasks_param_name: Optional[str] = None,
|
||||
security_scopes_param_name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
use_cache: bool = True,
|
||||
path: Optional[str] = None,
|
||||
) -> None:
|
||||
self.path_params = path_params or []
|
||||
self.query_params = query_params or []
|
||||
self.header_params = header_params or []
|
||||
self.cookie_params = cookie_params or []
|
||||
self.body_params = body_params or []
|
||||
self.dependencies = dependencies or []
|
||||
self.security_requirements = security_schemes or []
|
||||
self.request_param_name = request_param_name
|
||||
self.websocket_param_name = websocket_param_name
|
||||
self.http_connection_param_name = http_connection_param_name
|
||||
self.response_param_name = response_param_name
|
||||
self.background_tasks_param_name = background_tasks_param_name
|
||||
self.security_scopes = security_scopes
|
||||
self.security_scopes_param_name = security_scopes_param_name
|
||||
self.name = name
|
||||
self.call = call
|
||||
self.use_cache = use_cache
|
||||
# Store the path to be able to re-generate a dependable from it in overrides
|
||||
self.path = path
|
||||
# Save the cache key at creation to optimize performance
|
||||
path_params: List[ModelField] = field(default_factory=list)
|
||||
query_params: List[ModelField] = field(default_factory=list)
|
||||
header_params: List[ModelField] = field(default_factory=list)
|
||||
cookie_params: List[ModelField] = field(default_factory=list)
|
||||
body_params: List[ModelField] = field(default_factory=list)
|
||||
dependencies: List["Dependant"] = field(default_factory=list)
|
||||
security_requirements: List[SecurityRequirement] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
call: Optional[Callable[..., Any]] = None
|
||||
request_param_name: Optional[str] = None
|
||||
websocket_param_name: Optional[str] = None
|
||||
http_connection_param_name: Optional[str] = None
|
||||
response_param_name: Optional[str] = None
|
||||
background_tasks_param_name: Optional[str] = None
|
||||
security_scopes_param_name: Optional[str] = None
|
||||
security_scopes: Optional[List[str]] = None
|
||||
use_cache: bool = True
|
||||
path: Optional[str] = None
|
||||
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import inspect
|
||||
from contextlib import AsyncExitStack, contextmanager
|
||||
from copy import deepcopy
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@@ -23,7 +24,7 @@ from fastapi._compat import (
|
||||
PYDANTIC_V2,
|
||||
ErrorWrapper,
|
||||
ModelField,
|
||||
Required,
|
||||
RequiredParam,
|
||||
Undefined,
|
||||
_regenerate_error_with_loc,
|
||||
copy_field_info,
|
||||
@@ -31,6 +32,7 @@ from fastapi._compat import (
|
||||
evaluate_forwardref,
|
||||
field_annotation_is_scalar,
|
||||
get_annotation_from_field_info,
|
||||
get_cached_model_fields,
|
||||
get_missing_field_error,
|
||||
is_bytes_field,
|
||||
is_bytes_sequence_field,
|
||||
@@ -54,11 +56,18 @@ from fastapi.logger import logger
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||
from fastapi.utils import create_response_field, get_path_param_names
|
||||
from fastapi.utils import create_model_field, get_path_param_names
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||
from starlette.datastructures import (
|
||||
FormData,
|
||||
Headers,
|
||||
ImmutableMultiDict,
|
||||
QueryParams,
|
||||
UploadFile,
|
||||
)
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.websockets import WebSocket
|
||||
@@ -79,17 +88,23 @@ multipart_incorrect_install_error = (
|
||||
)
|
||||
|
||||
|
||||
def check_file_field(field: ModelField) -> None:
|
||||
field_info = field.field_info
|
||||
if isinstance(field_info, params.Form):
|
||||
def ensure_multipart_is_installed() -> None:
|
||||
try:
|
||||
from python_multipart import __version__
|
||||
|
||||
# Import an attribute that can be mocked/deleted in testing
|
||||
assert __version__ > "0.0.12"
|
||||
except (ImportError, AssertionError):
|
||||
try:
|
||||
# __version__ is available in both multiparts, and can be mocked
|
||||
from multipart import __version__ # type: ignore
|
||||
from multipart import __version__ # type: ignore[no-redef,import-untyped]
|
||||
|
||||
assert __version__
|
||||
try:
|
||||
# parse_options_header is only available in the right multipart
|
||||
from multipart.multipart import parse_options_header # type: ignore
|
||||
from multipart.multipart import ( # type: ignore[import-untyped]
|
||||
parse_options_header,
|
||||
)
|
||||
|
||||
assert parse_options_header
|
||||
except ImportError:
|
||||
@@ -175,7 +190,7 @@ def get_flat_dependant(
|
||||
header_params=dependant.header_params.copy(),
|
||||
cookie_params=dependant.cookie_params.copy(),
|
||||
body_params=dependant.body_params.copy(),
|
||||
security_schemes=dependant.security_requirements.copy(),
|
||||
security_requirements=dependant.security_requirements.copy(),
|
||||
use_cache=dependant.use_cache,
|
||||
path=dependant.path,
|
||||
)
|
||||
@@ -194,14 +209,23 @@ def get_flat_dependant(
|
||||
return flat_dependant
|
||||
|
||||
|
||||
def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
|
||||
if not fields:
|
||||
return fields
|
||||
first_field = fields[0]
|
||||
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
return fields_to_extract
|
||||
return fields
|
||||
|
||||
|
||||
def get_flat_params(dependant: Dependant) -> List[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
return (
|
||||
flat_dependant.path_params
|
||||
+ flat_dependant.query_params
|
||||
+ flat_dependant.header_params
|
||||
+ flat_dependant.cookie_params
|
||||
)
|
||||
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
|
||||
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
|
||||
header_params = _get_flat_fields_from_params(flat_dependant.header_params)
|
||||
cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
|
||||
return path_params + query_params + header_params + cookie_params
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
@@ -258,16 +282,16 @@ def get_dependant(
|
||||
)
|
||||
for param_name, param in signature_params.items():
|
||||
is_path_param = param_name in path_param_names
|
||||
type_annotation, depends, param_field = analyze_param(
|
||||
param_details = analyze_param(
|
||||
param_name=param_name,
|
||||
annotation=param.annotation,
|
||||
value=param.default,
|
||||
is_path_param=is_path_param,
|
||||
)
|
||||
if depends is not None:
|
||||
if param_details.depends is not None:
|
||||
sub_dependant = get_param_sub_dependant(
|
||||
param_name=param_name,
|
||||
depends=depends,
|
||||
depends=param_details.depends,
|
||||
path=path,
|
||||
security_scopes=security_scopes,
|
||||
)
|
||||
@@ -275,18 +299,18 @@ def get_dependant(
|
||||
continue
|
||||
if add_non_field_param_to_dependency(
|
||||
param_name=param_name,
|
||||
type_annotation=type_annotation,
|
||||
type_annotation=param_details.type_annotation,
|
||||
dependant=dependant,
|
||||
):
|
||||
assert (
|
||||
param_field is None
|
||||
param_details.field is None
|
||||
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
|
||||
continue
|
||||
assert param_field is not None
|
||||
if is_body_param(param_field=param_field, is_path_param=is_path_param):
|
||||
dependant.body_params.append(param_field)
|
||||
assert param_details.field is not None
|
||||
if isinstance(param_details.field.field_info, params.Body):
|
||||
dependant.body_params.append(param_details.field)
|
||||
else:
|
||||
add_param_to_fields(field=param_field, dependant=dependant)
|
||||
add_param_to_fields(field=param_details.field, dependant=dependant)
|
||||
return dependant
|
||||
|
||||
|
||||
@@ -314,13 +338,20 @@ def add_non_field_param_to_dependency(
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamDetails:
|
||||
type_annotation: Any
|
||||
depends: Optional[params.Depends]
|
||||
field: Optional[ModelField]
|
||||
|
||||
|
||||
def analyze_param(
|
||||
*,
|
||||
param_name: str,
|
||||
annotation: Any,
|
||||
value: Any,
|
||||
is_path_param: bool,
|
||||
) -> Tuple[Any, Optional[params.Depends], Optional[ModelField]]:
|
||||
) -> ParamDetails:
|
||||
field_info = None
|
||||
depends = None
|
||||
type_annotation: Any = Any
|
||||
@@ -328,6 +359,7 @@ def analyze_param(
|
||||
if annotation is not inspect.Signature.empty:
|
||||
use_annotation = annotation
|
||||
type_annotation = annotation
|
||||
# Extract Annotated info
|
||||
if get_origin(use_annotation) is Annotated:
|
||||
annotated_args = get_args(annotation)
|
||||
type_annotation = annotated_args[0]
|
||||
@@ -342,17 +374,20 @@ def analyze_param(
|
||||
if isinstance(arg, (params.Param, params.Body, params.Depends))
|
||||
]
|
||||
if fastapi_specific_annotations:
|
||||
fastapi_annotation: Union[
|
||||
FieldInfo, params.Depends, None
|
||||
] = fastapi_specific_annotations[-1]
|
||||
fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
|
||||
fastapi_specific_annotations[-1]
|
||||
)
|
||||
else:
|
||||
fastapi_annotation = None
|
||||
# Set default for Annotated FieldInfo
|
||||
if isinstance(fastapi_annotation, FieldInfo):
|
||||
# Copy `field_info` because we mutate `field_info.default` below.
|
||||
field_info = copy_field_info(
|
||||
field_info=fastapi_annotation, annotation=use_annotation
|
||||
)
|
||||
assert field_info.default is Undefined or field_info.default is Required, (
|
||||
assert (
|
||||
field_info.default is Undefined or field_info.default is RequiredParam
|
||||
), (
|
||||
f"`{field_info.__class__.__name__}` default value cannot be set in"
|
||||
f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
|
||||
)
|
||||
@@ -360,10 +395,11 @@ def analyze_param(
|
||||
assert not is_path_param, "Path parameters cannot have default values"
|
||||
field_info.default = value
|
||||
else:
|
||||
field_info.default = Required
|
||||
field_info.default = RequiredParam
|
||||
# Get Annotated Depends
|
||||
elif isinstance(fastapi_annotation, params.Depends):
|
||||
depends = fastapi_annotation
|
||||
|
||||
# Get Depends from default value
|
||||
if isinstance(value, params.Depends):
|
||||
assert depends is None, (
|
||||
"Cannot specify `Depends` in `Annotated` and default value"
|
||||
@@ -374,6 +410,7 @@ def analyze_param(
|
||||
f" default value together for {param_name!r}"
|
||||
)
|
||||
depends = value
|
||||
# Get FieldInfo from default value
|
||||
elif isinstance(value, FieldInfo):
|
||||
assert field_info is None, (
|
||||
"Cannot specify FastAPI annotations in `Annotated` and default value"
|
||||
@@ -383,9 +420,13 @@ def analyze_param(
|
||||
if PYDANTIC_V2:
|
||||
field_info.annotation = type_annotation
|
||||
|
||||
# Get Depends from type annotation
|
||||
if depends is not None and depends.dependency is None:
|
||||
# Copy `depends` before mutating it
|
||||
depends = copy(depends)
|
||||
depends.dependency = type_annotation
|
||||
|
||||
# Handle non-param type annotations like Request
|
||||
if lenient_issubclass(
|
||||
type_annotation,
|
||||
(
|
||||
@@ -401,10 +442,11 @@ def analyze_param(
|
||||
assert (
|
||||
field_info is None
|
||||
), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
|
||||
# Handle default assignations, neither field_info nor depends was not found in Annotated nor default value
|
||||
elif field_info is None and depends is None:
|
||||
default_value = value if value is not inspect.Signature.empty else Required
|
||||
default_value = value if value is not inspect.Signature.empty else RequiredParam
|
||||
if is_path_param:
|
||||
# We might check here that `default_value is Required`, but the fact is that the same
|
||||
# We might check here that `default_value is RequiredParam`, but the fact is that the same
|
||||
# parameter might sometimes be a path parameter and sometimes not. See
|
||||
# `tests/test_infer_param_optionality.py` for an example.
|
||||
field_info = params.Path(annotation=use_annotation)
|
||||
@@ -418,7 +460,9 @@ def analyze_param(
|
||||
field_info = params.Query(annotation=use_annotation, default=default_value)
|
||||
|
||||
field = None
|
||||
# It's a field_info, not a dependency
|
||||
if field_info is not None:
|
||||
# Handle field_info.in_
|
||||
if is_path_param:
|
||||
assert isinstance(field_info, params.Path), (
|
||||
f"Cannot use `{field_info.__class__.__name__}` for path param"
|
||||
@@ -434,40 +478,37 @@ def analyze_param(
|
||||
field_info,
|
||||
param_name,
|
||||
)
|
||||
if isinstance(field_info, params.Form):
|
||||
ensure_multipart_is_installed()
|
||||
if not field_info.alias and getattr(field_info, "convert_underscores", None):
|
||||
alias = param_name.replace("_", "-")
|
||||
else:
|
||||
alias = field_info.alias or param_name
|
||||
field_info.alias = alias
|
||||
field = create_response_field(
|
||||
field = create_model_field(
|
||||
name=param_name,
|
||||
type_=use_annotation_from_field_info,
|
||||
default=field_info.default,
|
||||
alias=alias,
|
||||
required=field_info.default in (Required, Undefined),
|
||||
required=field_info.default in (RequiredParam, Undefined),
|
||||
field_info=field_info,
|
||||
)
|
||||
if is_path_param:
|
||||
assert is_scalar_field(
|
||||
field=field
|
||||
), "Path params must be of one of the supported types"
|
||||
elif isinstance(field_info, params.Query):
|
||||
assert (
|
||||
is_scalar_field(field)
|
||||
or is_scalar_sequence_field(field)
|
||||
or (
|
||||
lenient_issubclass(field.type_, BaseModel)
|
||||
# For Pydantic v1
|
||||
and getattr(field, "shape", 1) == 1
|
||||
)
|
||||
)
|
||||
|
||||
return type_annotation, depends, field
|
||||
|
||||
|
||||
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
|
||||
if is_path_param:
|
||||
assert is_scalar_field(
|
||||
field=param_field
|
||||
), "Path params must be of one of the supported types"
|
||||
return False
|
||||
elif is_scalar_field(field=param_field):
|
||||
return False
|
||||
elif isinstance(
|
||||
param_field.field_info, (params.Query, params.Header)
|
||||
) and is_scalar_sequence_field(param_field):
|
||||
return False
|
||||
else:
|
||||
assert isinstance(
|
||||
param_field.field_info, params.Body
|
||||
), f"Param: {param_field.name} can only be a request body, using Body()"
|
||||
return True
|
||||
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
|
||||
|
||||
|
||||
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
||||
@@ -519,6 +560,15 @@ async def solve_generator(
|
||||
return await stack.enter_async_context(cm)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SolvedDependency:
|
||||
values: Dict[str, Any]
|
||||
errors: List[Any]
|
||||
background_tasks: Optional[StarletteBackgroundTasks]
|
||||
response: Response
|
||||
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
|
||||
|
||||
|
||||
async def solve_dependencies(
|
||||
*,
|
||||
request: Union[Request, WebSocket],
|
||||
@@ -529,13 +579,8 @@ async def solve_dependencies(
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
||||
async_exit_stack: AsyncExitStack,
|
||||
) -> Tuple[
|
||||
Dict[str, Any],
|
||||
List[Any],
|
||||
Optional[StarletteBackgroundTasks],
|
||||
Response,
|
||||
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
|
||||
]:
|
||||
embed_body_fields: bool,
|
||||
) -> SolvedDependency:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Any] = []
|
||||
if response is None:
|
||||
@@ -576,28 +621,23 @@ async def solve_dependencies(
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
dependency_cache=dependency_cache,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
(
|
||||
sub_values,
|
||||
sub_errors,
|
||||
background_tasks,
|
||||
_, # the subdependency returns the same response we have
|
||||
sub_dependency_cache,
|
||||
) = solved_result
|
||||
dependency_cache.update(sub_dependency_cache)
|
||||
if sub_errors:
|
||||
errors.extend(sub_errors)
|
||||
background_tasks = solved_result.background_tasks
|
||||
dependency_cache.update(solved_result.dependency_cache)
|
||||
if solved_result.errors:
|
||||
errors.extend(solved_result.errors)
|
||||
continue
|
||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependant.cache_key]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
solved = await solve_generator(
|
||||
call=call, stack=async_exit_stack, sub_values=sub_values
|
||||
call=call, stack=async_exit_stack, sub_values=solved_result.values
|
||||
)
|
||||
elif is_coroutine_callable(call):
|
||||
solved = await call(**sub_values)
|
||||
solved = await call(**solved_result.values)
|
||||
else:
|
||||
solved = await run_in_threadpool(call, **sub_values)
|
||||
solved = await run_in_threadpool(call, **solved_result.values)
|
||||
if sub_dependant.name is not None:
|
||||
values[sub_dependant.name] = solved
|
||||
if sub_dependant.cache_key not in dependency_cache:
|
||||
@@ -624,7 +664,9 @@ async def solve_dependencies(
|
||||
body_values,
|
||||
body_errors,
|
||||
) = await request_body_to_args( # body_params checked above
|
||||
required_params=dependant.body_params, received_body=body
|
||||
body_fields=dependant.body_params,
|
||||
received_body=body,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
values.update(body_values)
|
||||
errors.extend(body_errors)
|
||||
@@ -644,142 +686,257 @@ async def solve_dependencies(
|
||||
values[dependant.security_scopes_param_name] = SecurityScopes(
|
||||
scopes=dependant.security_scopes
|
||||
)
|
||||
return values, errors, background_tasks, response, dependency_cache
|
||||
return SolvedDependency(
|
||||
values=values,
|
||||
errors=errors,
|
||||
background_tasks=background_tasks,
|
||||
response=response,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
|
||||
|
||||
def _validate_value_with_model_field(
|
||||
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
|
||||
) -> Tuple[Any, List[Any]]:
|
||||
if value is None:
|
||||
if field.required:
|
||||
return None, [get_missing_field_error(loc=loc)]
|
||||
else:
|
||||
return deepcopy(field.default), []
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
return None, [errors_]
|
||||
elif isinstance(errors_, list):
|
||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
return None, new_errors
|
||||
else:
|
||||
return v_, []
|
||||
|
||||
|
||||
def _get_multidict_value(
|
||||
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
|
||||
) -> Any:
|
||||
alias = alias or field.alias
|
||||
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
|
||||
value = values.getlist(alias)
|
||||
else:
|
||||
value = values.get(alias, None)
|
||||
if (
|
||||
value is None
|
||||
or (
|
||||
isinstance(field.field_info, params.Form)
|
||||
and isinstance(value, str) # For type checks
|
||||
and value == ""
|
||||
)
|
||||
or (is_sequence_field(field) and len(value) == 0)
|
||||
):
|
||||
if field.required:
|
||||
return
|
||||
else:
|
||||
return deepcopy(field.default)
|
||||
return value
|
||||
|
||||
|
||||
def request_params_to_args(
|
||||
required_params: Sequence[ModelField],
|
||||
fields: Sequence[ModelField],
|
||||
received_params: Union[Mapping[str, Any], QueryParams, Headers],
|
||||
) -> Tuple[Dict[str, Any], List[Any]]:
|
||||
values = {}
|
||||
errors = []
|
||||
for field in required_params:
|
||||
if is_scalar_sequence_field(field) and isinstance(
|
||||
received_params, (QueryParams, Headers)
|
||||
):
|
||||
value = received_params.getlist(field.alias) or field.default
|
||||
else:
|
||||
value = received_params.get(field.alias)
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Dict[str, Any]] = []
|
||||
|
||||
if not fields:
|
||||
return values, errors
|
||||
|
||||
first_field = fields[0]
|
||||
fields_to_extract = fields
|
||||
single_not_embedded_field = False
|
||||
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
single_not_embedded_field = True
|
||||
|
||||
params_to_process: Dict[str, Any] = {}
|
||||
|
||||
processed_keys = set()
|
||||
|
||||
for field in fields_to_extract:
|
||||
alias = None
|
||||
if isinstance(received_params, Headers):
|
||||
# Handle fields extracted from a Pydantic Model for a header, each field
|
||||
# doesn't have a FieldInfo of type Header with the default convert_underscores=True
|
||||
convert_underscores = getattr(field.field_info, "convert_underscores", True)
|
||||
if convert_underscores:
|
||||
alias = (
|
||||
field.alias
|
||||
if field.alias != field.name
|
||||
else field.name.replace("_", "-")
|
||||
)
|
||||
value = _get_multidict_value(field, received_params, alias=alias)
|
||||
if value is not None:
|
||||
params_to_process[field.name] = value
|
||||
processed_keys.add(alias or field.alias)
|
||||
processed_keys.add(field.name)
|
||||
|
||||
for key, value in received_params.items():
|
||||
if key not in processed_keys:
|
||||
params_to_process[key] = value
|
||||
|
||||
if single_not_embedded_field:
|
||||
field_info = first_field.field_info
|
||||
assert isinstance(
|
||||
field_info, params.Param
|
||||
), "Params must be subclasses of Param"
|
||||
loc: Tuple[str, ...] = (field_info.in_.value,)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=first_field, value=params_to_process, values=values, loc=loc
|
||||
)
|
||||
return {first_field.name: v_}, errors_
|
||||
|
||||
for field in fields:
|
||||
value = _get_multidict_value(field, received_params)
|
||||
field_info = field.field_info
|
||||
assert isinstance(
|
||||
field_info, params.Param
|
||||
), "Params must be subclasses of Param"
|
||||
loc = (field_info.in_.value, field.alias)
|
||||
if value is None:
|
||||
if field.required:
|
||||
errors.append(get_missing_field_error(loc=loc))
|
||||
else:
|
||||
values[field.name] = deepcopy(field.default)
|
||||
continue
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
errors.append(errors_)
|
||||
elif isinstance(errors_, list):
|
||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
errors.extend(new_errors)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
if errors_:
|
||||
errors.extend(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
return values, errors
|
||||
|
||||
|
||||
async def request_body_to_args(
|
||||
required_params: List[ModelField],
|
||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
def _should_embed_body_fields(fields: List[ModelField]) -> bool:
|
||||
if not fields:
|
||||
return False
|
||||
# More than one dependency could have the same field, it would show up as multiple
|
||||
# fields but it's the same one, so count them by name
|
||||
body_param_names_set = {field.name for field in fields}
|
||||
# A top level field has to be a single field, not multiple
|
||||
if len(body_param_names_set) > 1:
|
||||
return True
|
||||
first_field = fields[0]
|
||||
# If it explicitly specifies it is embedded, it has to be embedded
|
||||
if getattr(first_field.field_info, "embed", None):
|
||||
return True
|
||||
# If it's a Form (or File) field, it has to be a BaseModel to be top level
|
||||
# otherwise it has to be embedded, so that the key value pair can be extracted
|
||||
if isinstance(first_field.field_info, params.Form) and not lenient_issubclass(
|
||||
first_field.type_, BaseModel
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _extract_form_body(
|
||||
body_fields: List[ModelField],
|
||||
received_body: FormData,
|
||||
) -> Dict[str, Any]:
|
||||
values = {}
|
||||
first_field = body_fields[0]
|
||||
first_field_info = first_field.field_info
|
||||
|
||||
for field in body_fields:
|
||||
value = _get_multidict_value(field, received_body)
|
||||
if (
|
||||
isinstance(first_field_info, params.File)
|
||||
and is_bytes_field(field)
|
||||
and isinstance(value, UploadFile)
|
||||
):
|
||||
value = await value.read()
|
||||
elif (
|
||||
is_bytes_sequence_field(field)
|
||||
and isinstance(first_field_info, params.File)
|
||||
and value_is_sequence(value)
|
||||
):
|
||||
# For types
|
||||
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
||||
results: List[Union[bytes, str]] = []
|
||||
|
||||
async def process_fn(
|
||||
fn: Callable[[], Coroutine[Any, Any, Any]],
|
||||
) -> None:
|
||||
result = await fn()
|
||||
results.append(result) # noqa: B023
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for sub_value in value:
|
||||
tg.start_soon(process_fn, sub_value.read)
|
||||
value = serialize_sequence_value(field=field, value=results)
|
||||
if value is not None:
|
||||
values[field.alias] = value
|
||||
for key, value in received_body.items():
|
||||
if key not in values:
|
||||
values[key] = value
|
||||
return values
|
||||
|
||||
|
||||
async def request_body_to_args(
|
||||
body_fields: List[ModelField],
|
||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||
embed_body_fields: bool,
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Dict[str, Any]] = []
|
||||
if required_params:
|
||||
field = required_params[0]
|
||||
field_info = field.field_info
|
||||
embed = getattr(field_info, "embed", None)
|
||||
field_alias_omitted = len(required_params) == 1 and not embed
|
||||
if field_alias_omitted:
|
||||
received_body = {field.alias: received_body}
|
||||
assert body_fields, "request_body_to_args() should be called with fields"
|
||||
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
|
||||
first_field = body_fields[0]
|
||||
body_to_process = received_body
|
||||
|
||||
for field in required_params:
|
||||
loc: Tuple[str, ...]
|
||||
if field_alias_omitted:
|
||||
loc = ("body",)
|
||||
else:
|
||||
loc = ("body", field.alias)
|
||||
fields_to_extract: List[ModelField] = body_fields
|
||||
|
||||
value: Optional[Any] = None
|
||||
if received_body is not None:
|
||||
if (is_sequence_field(field)) and isinstance(received_body, FormData):
|
||||
value = received_body.getlist(field.alias)
|
||||
else:
|
||||
try:
|
||||
value = received_body.get(field.alias)
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
continue
|
||||
if (
|
||||
value is None
|
||||
or (isinstance(field_info, params.Form) and value == "")
|
||||
or (
|
||||
isinstance(field_info, params.Form)
|
||||
and is_sequence_field(field)
|
||||
and len(value) == 0
|
||||
)
|
||||
):
|
||||
if field.required:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
else:
|
||||
values[field.name] = deepcopy(field.default)
|
||||
if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
|
||||
if isinstance(received_body, FormData):
|
||||
body_to_process = await _extract_form_body(fields_to_extract, received_body)
|
||||
|
||||
if single_not_embedded_field:
|
||||
loc: Tuple[str, ...] = ("body",)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=first_field, value=body_to_process, values=values, loc=loc
|
||||
)
|
||||
return {first_field.name: v_}, errors_
|
||||
for field in body_fields:
|
||||
loc = ("body", field.alias)
|
||||
value: Optional[Any] = None
|
||||
if body_to_process is not None:
|
||||
try:
|
||||
value = body_to_process.get(field.alias)
|
||||
# If the received body is a list, not a dict
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
continue
|
||||
if (
|
||||
isinstance(field_info, params.File)
|
||||
and is_bytes_field(field)
|
||||
and isinstance(value, UploadFile)
|
||||
):
|
||||
value = await value.read()
|
||||
elif (
|
||||
is_bytes_sequence_field(field)
|
||||
and isinstance(field_info, params.File)
|
||||
and value_is_sequence(value)
|
||||
):
|
||||
# For types
|
||||
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
||||
results: List[Union[bytes, str]] = []
|
||||
|
||||
async def process_fn(
|
||||
fn: Callable[[], Coroutine[Any, Any, Any]]
|
||||
) -> None:
|
||||
result = await fn()
|
||||
results.append(result) # noqa: B023
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for sub_value in value:
|
||||
tg.start_soon(process_fn, sub_value.read)
|
||||
value = serialize_sequence_value(field=field, value=results)
|
||||
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
|
||||
if isinstance(errors_, list):
|
||||
errors.extend(errors_)
|
||||
elif errors_:
|
||||
errors.append(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
if errors_:
|
||||
errors.extend(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
return values, errors
|
||||
|
||||
|
||||
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant)
|
||||
def get_body_field(
|
||||
*, flat_dependant: Dependant, name: str, embed_body_fields: bool
|
||||
) -> Optional[ModelField]:
|
||||
"""
|
||||
Get a ModelField representing the request body for a path operation, combining
|
||||
all body parameters into a single field if necessary.
|
||||
|
||||
Used to check if it's form data (with `isinstance(body_field, params.Form)`)
|
||||
or JSON and to generate the JSON Schema for a request body.
|
||||
|
||||
This is **not** used to validate/parse the request body, that's done with each
|
||||
individual body parameter.
|
||||
"""
|
||||
if not flat_dependant.body_params:
|
||||
return None
|
||||
first_param = flat_dependant.body_params[0]
|
||||
field_info = first_param.field_info
|
||||
embed = getattr(field_info, "embed", None)
|
||||
body_param_names_set = {param.name for param in flat_dependant.body_params}
|
||||
if len(body_param_names_set) == 1 and not embed:
|
||||
check_file_field(first_param)
|
||||
if not embed_body_fields:
|
||||
return first_param
|
||||
# If one field requires to embed, all have to be embedded
|
||||
# in case a sub-dependency is evaluated with a single unique body field
|
||||
# That is combined (embedded) with other body fields
|
||||
for param in flat_dependant.body_params:
|
||||
setattr(param.field_info, "embed", True) # noqa: B010
|
||||
model_name = "Body_" + name
|
||||
BodyModel = create_body_model(
|
||||
fields=flat_dependant.body_params, model_name=model_name
|
||||
@@ -805,12 +962,11 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
|
||||
]
|
||||
if len(set(body_param_media_types)) == 1:
|
||||
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
|
||||
final_field = create_response_field(
|
||||
final_field = create_model_field(
|
||||
name="body",
|
||||
type_=BodyModel,
|
||||
required=required,
|
||||
alias="body",
|
||||
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
|
||||
)
|
||||
check_file_field(final_field)
|
||||
return final_field
|
||||
|
Reference in New Issue
Block a user