mirror of
https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git
synced 2025-08-14 00:25:46 +02:00
Проверка 09.02.2025
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = "0.36.3"
|
||||
__version__ = "0.45.3"
|
||||
|
@@ -1,28 +0,0 @@
|
||||
import hashlib
|
||||
|
||||
# Compat wrapper to always include the `usedforsecurity=...` parameter,
|
||||
# which is only added from Python 3.9 onwards.
|
||||
# We use this flag to indicate that we use `md5` hashes only for non-security
|
||||
# cases (our ETag checksums).
|
||||
# If we don't indicate that we're using MD5 for non-security related reasons,
|
||||
# then attempting to use this function will raise an error when used
|
||||
# environments which enable a strict "FIPs mode".
|
||||
#
|
||||
# See issue: https://github.com/encode/starlette/issues/1365
|
||||
try:
|
||||
# check if the Python version supports the parameter
|
||||
# using usedforsecurity=False to avoid an exception on FIPS systems
|
||||
# that reject usedforsecurity=True
|
||||
hashlib.md5(b"data", usedforsecurity=False) # type: ignore[call-arg]
|
||||
|
||||
def md5_hexdigest(
|
||||
data: bytes, *, usedforsecurity: bool = True
|
||||
) -> str: # pragma: no cover
|
||||
return hashlib.md5( # type: ignore[call-arg]
|
||||
data, usedforsecurity=usedforsecurity
|
||||
).hexdigest()
|
||||
|
||||
except TypeError: # pragma: no cover
|
||||
|
||||
def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str:
|
||||
return hashlib.md5(data).hexdigest()
|
@@ -6,25 +6,14 @@ from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.types import (
|
||||
ASGIApp,
|
||||
ExceptionHandler,
|
||||
HTTPExceptionHandler,
|
||||
Message,
|
||||
Receive,
|
||||
Scope,
|
||||
Send,
|
||||
WebSocketExceptionHandler,
|
||||
)
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
|
||||
StatusHandlers = typing.Dict[int, ExceptionHandler]
|
||||
ExceptionHandlers = dict[typing.Any, ExceptionHandler]
|
||||
StatusHandlers = dict[int, ExceptionHandler]
|
||||
|
||||
|
||||
def _lookup_exception_handler(
|
||||
exc_handlers: ExceptionHandlers, exc: Exception
|
||||
) -> ExceptionHandler | None:
|
||||
def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
|
||||
for cls in type(exc).__mro__:
|
||||
if cls in exc_handlers:
|
||||
return exc_handlers[cls]
|
||||
@@ -64,24 +53,13 @@ def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASG
|
||||
raise exc
|
||||
|
||||
if response_started:
|
||||
msg = "Caught handled exception, but response already started."
|
||||
raise RuntimeError(msg) from exc
|
||||
raise RuntimeError("Caught handled exception, but response already started.") from exc
|
||||
|
||||
if scope["type"] == "http":
|
||||
nonlocal conn
|
||||
handler = typing.cast(HTTPExceptionHandler, handler)
|
||||
conn = typing.cast(Request, conn)
|
||||
if is_async_callable(handler):
|
||||
response = await handler(conn, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, conn, exc)
|
||||
if is_async_callable(handler):
|
||||
response = await handler(conn, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, conn, exc) # type: ignore
|
||||
if response is not None:
|
||||
await response(scope, receive, sender)
|
||||
elif scope["type"] == "websocket":
|
||||
handler = typing.cast(WebSocketExceptionHandler, handler)
|
||||
conn = typing.cast(WebSocket, conn)
|
||||
if is_async_callable(handler):
|
||||
await handler(conn, exc)
|
||||
else:
|
||||
await run_in_threadpool(handler, conn, exc)
|
||||
|
||||
return wrapped_app
|
||||
|
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import re
|
||||
import sys
|
||||
import typing
|
||||
from contextlib import contextmanager
|
||||
@@ -17,7 +16,7 @@ else: # pragma: no cover
|
||||
has_exceptiongroups = True
|
||||
if sys.version_info < (3, 11): # pragma: no cover
|
||||
try:
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found]
|
||||
except ImportError:
|
||||
has_exceptiongroups = False
|
||||
|
||||
@@ -26,41 +25,31 @@ AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
|
||||
|
||||
|
||||
@typing.overload
|
||||
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]:
|
||||
...
|
||||
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]:
|
||||
...
|
||||
def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ...
|
||||
|
||||
|
||||
def is_async_callable(obj: typing.Any) -> typing.Any:
|
||||
while isinstance(obj, functools.partial):
|
||||
obj = obj.func
|
||||
|
||||
return asyncio.iscoroutinefunction(obj) or (
|
||||
callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
|
||||
)
|
||||
return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))
|
||||
|
||||
|
||||
T_co = typing.TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
class AwaitableOrContextManager(
|
||||
typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]
|
||||
):
|
||||
...
|
||||
class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ...
|
||||
|
||||
|
||||
class SupportsAsyncClose(typing.Protocol):
|
||||
async def close(self) -> None:
|
||||
... # pragma: no cover
|
||||
async def close(self) -> None: ... # pragma: no cover
|
||||
|
||||
|
||||
SupportsAsyncCloseType = typing.TypeVar(
|
||||
"SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
|
||||
)
|
||||
SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
|
||||
|
||||
|
||||
class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
|
||||
@@ -86,14 +75,26 @@ def collapse_excgroups() -> typing.Generator[None, None, None]:
|
||||
try:
|
||||
yield
|
||||
except BaseException as exc:
|
||||
if has_exceptiongroups:
|
||||
if has_exceptiongroups: # pragma: no cover
|
||||
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
|
||||
exc = exc.exceptions[0] # pragma: no cover
|
||||
exc = exc.exceptions[0]
|
||||
|
||||
raise exc
|
||||
|
||||
|
||||
def get_route_path(scope: Scope) -> str:
|
||||
path: str = scope["path"]
|
||||
root_path = scope.get("root_path", "")
|
||||
route_path = re.sub(r"^" + root_path, "", scope["path"])
|
||||
return route_path
|
||||
if not root_path:
|
||||
return path
|
||||
|
||||
if not path.startswith(root_path):
|
||||
return path
|
||||
|
||||
if path == root_path:
|
||||
return ""
|
||||
|
||||
if path[len(root_path)] == "/":
|
||||
return path[len(root_path) :]
|
||||
|
||||
return path
|
||||
|
@@ -10,7 +10,7 @@ else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from starlette.datastructures import State, URLPath
|
||||
from starlette.middleware import Middleware, _MiddlewareClass
|
||||
from starlette.middleware import Middleware, _MiddlewareFactory
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.errors import ServerErrorMiddleware
|
||||
from starlette.middleware.exceptions import ExceptionMiddleware
|
||||
@@ -25,34 +25,7 @@ P = ParamSpec("P")
|
||||
|
||||
|
||||
class Starlette:
|
||||
"""
|
||||
Creates an application instance.
|
||||
|
||||
**Parameters:**
|
||||
|
||||
* **debug** - Boolean indicating if debug tracebacks should be returned on errors.
|
||||
* **routes** - A list of routes to serve incoming HTTP and WebSocket requests.
|
||||
* **middleware** - A list of middleware to run for every request. A starlette
|
||||
application will always automatically include two middleware classes.
|
||||
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
|
||||
any uncaught errors occurring anywhere in the entire stack.
|
||||
`ExceptionMiddleware` is added as the very innermost middleware, to deal
|
||||
with handled exception cases occurring in the routing or endpoints.
|
||||
* **exception_handlers** - A mapping of either integer status codes,
|
||||
or exception class types onto callables which handle the exceptions.
|
||||
Exception handler callables should be of the form
|
||||
`handler(request, exc) -> response` and may be either standard functions, or
|
||||
async functions.
|
||||
* **on_startup** - A list of callables to run on application startup.
|
||||
Startup handler callables do not take any arguments, and may be either
|
||||
standard functions, or async functions.
|
||||
* **on_shutdown** - A list of callables to run on application shutdown.
|
||||
Shutdown handler callables do not take any arguments, and may be either
|
||||
standard functions, or async functions.
|
||||
* **lifespan** - A lifespan context function, which can be used to perform
|
||||
startup and shutdown tasks. This is a newer style that replaces the
|
||||
`on_startup` and `on_shutdown` handlers. Use one or the other, not both.
|
||||
"""
|
||||
"""Creates an Starlette application."""
|
||||
|
||||
def __init__(
|
||||
self: AppType,
|
||||
@@ -64,6 +37,32 @@ class Starlette:
|
||||
on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
|
||||
lifespan: Lifespan[AppType] | None = None,
|
||||
) -> None:
|
||||
"""Initializes the application.
|
||||
|
||||
Parameters:
|
||||
debug: Boolean indicating if debug tracebacks should be returned on errors.
|
||||
routes: A list of routes to serve incoming HTTP and WebSocket requests.
|
||||
middleware: A list of middleware to run for every request. A starlette
|
||||
application will always automatically include two middleware classes.
|
||||
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
|
||||
any uncaught errors occurring anywhere in the entire stack.
|
||||
`ExceptionMiddleware` is added as the very innermost middleware, to deal
|
||||
with handled exception cases occurring in the routing or endpoints.
|
||||
exception_handlers: A mapping of either integer status codes,
|
||||
or exception class types onto callables which handle the exceptions.
|
||||
Exception handler callables should be of the form
|
||||
`handler(request, exc) -> response` and may be either standard functions, or
|
||||
async functions.
|
||||
on_startup: A list of callables to run on application startup.
|
||||
Startup handler callables do not take any arguments, and may be either
|
||||
standard functions, or async functions.
|
||||
on_shutdown: A list of callables to run on application shutdown.
|
||||
Shutdown handler callables do not take any arguments, and may be either
|
||||
standard functions, or async functions.
|
||||
lifespan: A lifespan context function, which can be used to perform
|
||||
startup and shutdown tasks. This is a newer style that replaces the
|
||||
`on_startup` and `on_shutdown` handlers. Use one or the other, not both.
|
||||
"""
|
||||
# The lifespan context function is a newer style that replaces
|
||||
# on_startup / on_shutdown handlers. Use one or the other, not both.
|
||||
assert lifespan is None or (
|
||||
@@ -72,21 +71,15 @@ class Starlette:
|
||||
|
||||
self.debug = debug
|
||||
self.state = State()
|
||||
self.router = Router(
|
||||
routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
|
||||
)
|
||||
self.exception_handlers = (
|
||||
{} if exception_handlers is None else dict(exception_handlers)
|
||||
)
|
||||
self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
|
||||
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
|
||||
self.user_middleware = [] if middleware is None else list(middleware)
|
||||
self.middleware_stack: typing.Optional[ASGIApp] = None
|
||||
self.middleware_stack: ASGIApp | None = None
|
||||
|
||||
def build_middleware_stack(self) -> ASGIApp:
|
||||
debug = self.debug
|
||||
error_handler = None
|
||||
exception_handlers: dict[
|
||||
typing.Any, typing.Callable[[Request, Exception], Response]
|
||||
] = {}
|
||||
exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {}
|
||||
|
||||
for key, value in self.exception_handlers.items():
|
||||
if key in (500, Exception):
|
||||
@@ -97,16 +90,12 @@ class Starlette:
|
||||
middleware = (
|
||||
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
|
||||
+ self.user_middleware
|
||||
+ [
|
||||
Middleware(
|
||||
ExceptionMiddleware, handlers=exception_handlers, debug=debug
|
||||
)
|
||||
]
|
||||
+ [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
|
||||
)
|
||||
|
||||
app = self.router
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
app = cls(app=app, *args, **kwargs)
|
||||
app = cls(app, *args, **kwargs)
|
||||
return app
|
||||
|
||||
@property
|
||||
@@ -123,7 +112,7 @@ class Starlette:
|
||||
await self.middleware_stack(scope, receive, send)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
|
||||
return self.router.on_event(event_type) # pragma: nocover
|
||||
return self.router.on_event(event_type) # pragma: no cover
|
||||
|
||||
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
|
||||
self.router.mount(path, app=app, name=name) # pragma: no cover
|
||||
@@ -133,7 +122,7 @@ class Starlette:
|
||||
|
||||
def add_middleware(
|
||||
self,
|
||||
middleware_class: typing.Type[_MiddlewareClass[P]],
|
||||
middleware_class: _MiddlewareFactory[P],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
@@ -143,7 +132,7 @@ class Starlette:
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: int | typing.Type[Exception],
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: ExceptionHandler,
|
||||
) -> None: # pragma: no cover
|
||||
self.exception_handlers[exc_class_or_status_code] = handler
|
||||
@@ -159,13 +148,11 @@ class Starlette:
|
||||
self,
|
||||
path: str,
|
||||
route: typing.Callable[[Request], typing.Awaitable[Response] | Response],
|
||||
methods: typing.Optional[typing.List[str]] = None,
|
||||
name: typing.Optional[str] = None,
|
||||
methods: list[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None: # pragma: no cover
|
||||
self.router.add_route(
|
||||
path, route, methods=methods, name=name, include_in_schema=include_in_schema
|
||||
)
|
||||
self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema)
|
||||
|
||||
def add_websocket_route(
|
||||
self,
|
||||
@@ -175,16 +162,14 @@ class Starlette:
|
||||
) -> None: # pragma: no cover
|
||||
self.router.add_websocket_route(path, route, name=name)
|
||||
|
||||
def exception_handler(
|
||||
self, exc_class_or_status_code: int | typing.Type[Exception]
|
||||
) -> typing.Callable: # type: ignore[type-arg]
|
||||
def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg]
|
||||
warnings.warn(
|
||||
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.", # noqa: E501
|
||||
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.add_exception_handler(exc_class_or_status_code, func)
|
||||
return func
|
||||
|
||||
@@ -205,12 +190,12 @@ class Starlette:
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"Refer to https://www.starlette.io/routing/ for the recommended approach.", # noqa: E501
|
||||
"The `route` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://www.starlette.io/routing/ for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.router.add_route(
|
||||
path,
|
||||
func,
|
||||
@@ -231,18 +216,18 @@ class Starlette:
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.router.add_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
|
||||
def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg]
|
||||
"""
|
||||
We no longer document this decorator style API, and its usage is discouraged.
|
||||
Instead you should use the following approach:
|
||||
@@ -251,15 +236,13 @@ class Starlette:
|
||||
>>> app = Starlette(middleware=middleware)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `middleware` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.", # noqa: E501
|
||||
"The `middleware` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert (
|
||||
middleware_type == "http"
|
||||
), 'Currently only middleware("http") is supported.'
|
||||
assert middleware_type == "http", 'Currently only middleware("http") is supported.'
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
|
||||
return func
|
||||
|
||||
|
@@ -31,9 +31,7 @@ def requires(
|
||||
scopes: str | typing.Sequence[str],
|
||||
status_code: int = 403,
|
||||
redirect: str | None = None,
|
||||
) -> typing.Callable[
|
||||
[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]
|
||||
]:
|
||||
) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
|
||||
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
|
||||
|
||||
def decorator(
|
||||
@@ -45,17 +43,13 @@ def requires(
|
||||
type_ = parameter.name
|
||||
break
|
||||
else:
|
||||
raise Exception(
|
||||
f'No "request" or "websocket" argument on function "{func}"'
|
||||
)
|
||||
raise Exception(f'No "request" or "websocket" argument on function "{func}"')
|
||||
|
||||
if type_ == "websocket":
|
||||
# Handle websocket functions. (Always async)
|
||||
@functools.wraps(func)
|
||||
async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
|
||||
websocket = kwargs.get(
|
||||
"websocket", args[idx] if idx < len(args) else None
|
||||
)
|
||||
websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
|
||||
assert isinstance(websocket, WebSocket)
|
||||
|
||||
if not has_required_scope(websocket, scopes_list):
|
||||
@@ -75,10 +69,7 @@ def requires(
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
orig_request_qparam = urlencode({"next": str(request.url)})
|
||||
next_url = "{redirect_path}?{orig_request}".format(
|
||||
redirect_path=request.url_for(redirect),
|
||||
orig_request=orig_request_qparam,
|
||||
)
|
||||
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
|
||||
return RedirectResponse(url=next_url, status_code=303)
|
||||
raise HTTPException(status_code=status_code)
|
||||
return await func(*args, **kwargs)
|
||||
@@ -95,10 +86,7 @@ def requires(
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
orig_request_qparam = urlencode({"next": str(request.url)})
|
||||
next_url = "{redirect_path}?{orig_request}".format(
|
||||
redirect_path=request.url_for(redirect),
|
||||
orig_request=orig_request_qparam,
|
||||
)
|
||||
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
|
||||
return RedirectResponse(url=next_url, status_code=303)
|
||||
raise HTTPException(status_code=status_code)
|
||||
return func(*args, **kwargs)
|
||||
@@ -113,9 +101,7 @@ class AuthenticationError(Exception):
|
||||
|
||||
|
||||
class AuthenticationBackend:
|
||||
async def authenticate(
|
||||
self, conn: HTTPConnection
|
||||
) -> tuple[AuthCredentials, BaseUser] | None:
|
||||
async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
|
@@ -15,9 +15,7 @@ P = ParamSpec("P")
|
||||
|
||||
|
||||
class BackgroundTask:
|
||||
def __init__(
|
||||
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
|
||||
) -> None:
|
||||
def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
@@ -34,9 +32,7 @@ class BackgroundTasks(BackgroundTask):
|
||||
def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
|
||||
self.tasks = list(tasks) if tasks else []
|
||||
|
||||
def add_task(
|
||||
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
|
||||
) -> None:
|
||||
def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
task = BackgroundTask(func, *args, **kwargs)
|
||||
self.tasks.append(task)
|
||||
|
||||
|
@@ -16,16 +16,15 @@ P = ParamSpec("P")
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501
|
||||
async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg]
|
||||
warnings.warn(
|
||||
"run_until_first_complete is deprecated "
|
||||
"and will be removed in a future version.",
|
||||
"run_until_first_complete is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] # noqa: E501
|
||||
async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg]
|
||||
await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
|
||||
@@ -33,13 +32,9 @@ async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None:
|
||||
task_group.start_soon(run, functools.partial(func, **kwargs))
|
||||
|
||||
|
||||
async def run_in_threadpool(
|
||||
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
|
||||
) -> T:
|
||||
if kwargs: # pragma: no cover
|
||||
# run_sync doesn't accept 'kwargs', so bind them in here
|
||||
func = functools.partial(func, **kwargs)
|
||||
return await anyio.to_thread.run_sync(func, *args)
|
||||
async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
func = functools.partial(func, *args, **kwargs)
|
||||
return await anyio.to_thread.run_sync(func)
|
||||
|
||||
|
||||
class _StopIteration(Exception):
|
||||
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -16,7 +17,7 @@ class EnvironError(Exception):
|
||||
class Environ(typing.MutableMapping[str, str]):
|
||||
def __init__(self, environ: typing.MutableMapping[str, str] = os.environ):
|
||||
self._environ = environ
|
||||
self._has_been_read: typing.Set[str] = set()
|
||||
self._has_been_read: set[str] = set()
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
self._has_been_read.add(key)
|
||||
@@ -24,18 +25,12 @@ class Environ(typing.MutableMapping[str, str]):
|
||||
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
if key in self._has_been_read:
|
||||
raise EnvironError(
|
||||
f"Attempting to set environ['{key}'], but the value has already been "
|
||||
"read."
|
||||
)
|
||||
raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
|
||||
self._environ.__setitem__(key, value)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
if key in self._has_been_read:
|
||||
raise EnvironError(
|
||||
f"Attempting to delete environ['{key}'], but the value has already "
|
||||
"been read."
|
||||
)
|
||||
raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
|
||||
self._environ.__delitem__(key)
|
||||
|
||||
def __iter__(self) -> typing.Iterator[str]:
|
||||
@@ -59,23 +54,21 @@ class Config:
|
||||
) -> None:
|
||||
self.environ = environ
|
||||
self.env_prefix = env_prefix
|
||||
self.file_values: typing.Dict[str, str] = {}
|
||||
self.file_values: dict[str, str] = {}
|
||||
if env_file is not None:
|
||||
if not os.path.isfile(env_file):
|
||||
raise FileNotFoundError(f"Config file '{env_file}' not found.")
|
||||
self.file_values = self._read_file(env_file)
|
||||
warnings.warn(f"Config file '{env_file}' not found.")
|
||||
else:
|
||||
self.file_values = self._read_file(env_file)
|
||||
|
||||
@typing.overload
|
||||
def __call__(self, key: str, *, default: None) -> str | None:
|
||||
...
|
||||
def __call__(self, key: str, *, default: None) -> str | None: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(self, key: str, cast: type[T], default: T = ...) -> T:
|
||||
...
|
||||
def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str:
|
||||
...
|
||||
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(
|
||||
@@ -83,12 +76,10 @@ class Config:
|
||||
key: str,
|
||||
cast: typing.Callable[[typing.Any], T] = ...,
|
||||
default: typing.Any = ...,
|
||||
) -> T:
|
||||
...
|
||||
) -> T: ...
|
||||
|
||||
@typing.overload
|
||||
def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str:
|
||||
...
|
||||
def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -116,7 +107,7 @@ class Config:
|
||||
raise KeyError(f"Config '{key}' is missing, and has no default.")
|
||||
|
||||
def _read_file(self, file_name: str | Path) -> dict[str, str]:
|
||||
file_values: typing.Dict[str, str] = {}
|
||||
file_values: dict[str, str] = {}
|
||||
with open(file_name) as input_file:
|
||||
for line in input_file.readlines():
|
||||
line = line.strip()
|
||||
@@ -139,13 +130,9 @@ class Config:
|
||||
mapping = {"true": True, "1": True, "false": False, "0": False}
|
||||
value = value.lower()
|
||||
if value not in mapping:
|
||||
raise ValueError(
|
||||
f"Config '{key}' has value '{value}'. Not a valid bool."
|
||||
)
|
||||
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
|
||||
return mapping[value]
|
||||
try:
|
||||
return cast(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}."
|
||||
)
|
||||
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.")
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import typing
|
||||
import uuid
|
||||
@@ -65,7 +67,7 @@ class FloatConvertor(Convertor[float]):
|
||||
|
||||
|
||||
class UUIDConvertor(Convertor[uuid.UUID]):
|
||||
regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
regex = "[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}"
|
||||
|
||||
def convert(self, value: str) -> uuid.UUID:
|
||||
return uuid.UUID(value)
|
||||
@@ -74,7 +76,7 @@ class UUIDConvertor(Convertor[uuid.UUID]):
|
||||
return str(value)
|
||||
|
||||
|
||||
CONVERTOR_TYPES: typing.Dict[str, Convertor[typing.Any]] = {
|
||||
CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = {
|
||||
"str": StringConvertor(),
|
||||
"path": PathConvertor(),
|
||||
"int": IntegerConvertor(),
|
||||
|
@@ -108,12 +108,7 @@ class URL:
|
||||
return self.scheme in ("https", "wss")
|
||||
|
||||
def replace(self, **kwargs: typing.Any) -> URL:
|
||||
if (
|
||||
"username" in kwargs
|
||||
or "password" in kwargs
|
||||
or "hostname" in kwargs
|
||||
or "port" in kwargs
|
||||
):
|
||||
if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
|
||||
hostname = kwargs.pop("hostname", None)
|
||||
port = kwargs.pop("port", self.port)
|
||||
username = kwargs.pop("username", self.username)
|
||||
@@ -150,7 +145,7 @@ class URL:
|
||||
query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
|
||||
return self.replace(query=query)
|
||||
|
||||
def remove_query_params(self, keys: str | typing.Sequence[str]) -> "URL":
|
||||
def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
|
||||
if isinstance(keys, str):
|
||||
keys = [keys]
|
||||
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
|
||||
@@ -178,7 +173,7 @@ class URLPath(str):
|
||||
Used by the routing to return `url_path_for` matches.
|
||||
"""
|
||||
|
||||
def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
|
||||
def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
|
||||
assert protocol in ("http", "websocket", "")
|
||||
return str.__new__(cls, path)
|
||||
|
||||
@@ -251,30 +246,25 @@ class CommaSeparatedStrings(typing.Sequence[str]):
|
||||
|
||||
|
||||
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
|
||||
_dict: typing.Dict[_KeyType, _CovariantValueType]
|
||||
_dict: dict[_KeyType, _CovariantValueType]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: ImmutableMultiDict[_KeyType, _CovariantValueType]
|
||||
| typing.Mapping[_KeyType, _CovariantValueType]
|
||||
| typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]],
|
||||
| typing.Iterable[tuple[_KeyType, _CovariantValueType]],
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
assert len(args) < 2, "Too many arguments."
|
||||
|
||||
value: typing.Any = args[0] if args else []
|
||||
if kwargs:
|
||||
value = (
|
||||
ImmutableMultiDict(value).multi_items()
|
||||
+ ImmutableMultiDict(kwargs).multi_items()
|
||||
)
|
||||
value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
|
||||
|
||||
if not value:
|
||||
_items: list[tuple[typing.Any, typing.Any]] = []
|
||||
elif hasattr(value, "multi_items"):
|
||||
value = typing.cast(
|
||||
ImmutableMultiDict[_KeyType, _CovariantValueType], value
|
||||
)
|
||||
value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
|
||||
_items = list(value.multi_items())
|
||||
elif hasattr(value, "items"):
|
||||
value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
|
||||
@@ -371,9 +361,7 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
|
||||
|
||||
def update(
|
||||
self,
|
||||
*args: MultiDict
|
||||
| typing.Mapping[typing.Any, typing.Any]
|
||||
| list[tuple[typing.Any, typing.Any]],
|
||||
*args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
value = MultiDict(*args, **kwargs)
|
||||
@@ -403,9 +391,7 @@ class QueryParams(ImmutableMultiDict[str, str]):
|
||||
if isinstance(value, str):
|
||||
super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
|
||||
elif isinstance(value, bytes):
|
||||
super().__init__(
|
||||
parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
|
||||
)
|
||||
super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
|
||||
else:
|
||||
super().__init__(*args, **kwargs) # type: ignore[arg-type]
|
||||
self._list = [(str(k), str(v)) for k, v in self._list]
|
||||
@@ -490,9 +476,7 @@ class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: FormData
|
||||
| typing.Mapping[str, str | UploadFile]
|
||||
| list[tuple[str, str | UploadFile]],
|
||||
*args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
|
||||
**kwargs: str | UploadFile,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -518,10 +502,7 @@ class Headers(typing.Mapping[str, str]):
|
||||
if headers is not None:
|
||||
assert raw is None, 'Cannot set both "headers" and "raw".'
|
||||
assert scope is None, 'Cannot set both "headers" and "scope".'
|
||||
self._list = [
|
||||
(key.lower().encode("latin-1"), value.encode("latin-1"))
|
||||
for key, value in headers.items()
|
||||
]
|
||||
self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
|
||||
elif raw is not None:
|
||||
assert scope is None, 'Cannot set both "raw" and "scope".'
|
||||
self._list = raw
|
||||
@@ -541,18 +522,11 @@ class Headers(typing.Mapping[str, str]):
|
||||
return [value.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def items(self) -> list[tuple[str, str]]: # type: ignore[override]
|
||||
return [
|
||||
(key.decode("latin-1"), value.decode("latin-1"))
|
||||
for key, value in self._list
|
||||
]
|
||||
return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
|
||||
|
||||
def getlist(self, key: str) -> list[str]:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
return [
|
||||
item_value.decode("latin-1")
|
||||
for item_key, item_value in self._list
|
||||
if item_key == get_header_key
|
||||
]
|
||||
return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
|
||||
|
||||
def mutablecopy(self) -> MutableHeaders:
|
||||
return MutableHeaders(raw=self._list[:])
|
||||
@@ -599,7 +573,7 @@ class MutableHeaders(Headers):
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
found_indexes: "typing.List[int]" = []
|
||||
found_indexes: list[int] = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
found_indexes.append(idx)
|
||||
@@ -619,7 +593,7 @@ class MutableHeaders(Headers):
|
||||
"""
|
||||
del_key = key.lower().encode("latin-1")
|
||||
|
||||
pop_indexes: "typing.List[int]" = []
|
||||
pop_indexes: list[int] = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == del_key:
|
||||
pop_indexes.append(idx)
|
||||
|
@@ -30,15 +30,9 @@ class HTTPEndpoint:
|
||||
|
||||
async def dispatch(self) -> None:
|
||||
request = Request(self.scope, receive=self.receive)
|
||||
handler_name = (
|
||||
"get"
|
||||
if request.method == "HEAD" and not hasattr(self, "head")
|
||||
else request.method.lower()
|
||||
)
|
||||
handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
|
||||
|
||||
handler: typing.Callable[[Request], typing.Any] = getattr(
|
||||
self, handler_name, self.method_not_allowed
|
||||
)
|
||||
handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed)
|
||||
is_async = is_async_callable(handler)
|
||||
if is_async:
|
||||
response = await handler(request)
|
||||
@@ -80,10 +74,8 @@ class WebSocketEndpoint:
|
||||
if message["type"] == "websocket.receive":
|
||||
data = await self.decode(websocket, message)
|
||||
await self.on_receive(websocket, data)
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
close_code = int(
|
||||
message.get("code") or status.WS_1000_NORMAL_CLOSURE
|
||||
)
|
||||
elif message["type"] == "websocket.disconnect": # pragma: no branch
|
||||
close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
|
||||
break
|
||||
except Exception as exc:
|
||||
close_code = status.WS_1011_INTERNAL_ERROR
|
||||
@@ -116,9 +108,7 @@ class WebSocketEndpoint:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
raise RuntimeError("Malformed JSON data received.")
|
||||
|
||||
assert (
|
||||
self.encoding is None
|
||||
), f"Unsupported 'encoding' attribute {self.encoding}"
|
||||
assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
|
||||
return message["text"] if message.get("text") else message["bytes"]
|
||||
|
||||
async def on_connect(self, websocket: WebSocket) -> None:
|
||||
|
@@ -1,19 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import http
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
__all__ = ("HTTPException", "WebSocketException")
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
class HTTPException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
detail: str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
def __init__(self, status_code: int, detail: str | None = None, headers: Mapping[str, str] | None = None) -> None:
|
||||
if detail is None:
|
||||
detail = http.HTTPStatus(status_code).phrase
|
||||
self.status_code = status_code
|
||||
@@ -39,24 +31,3 @@ class WebSocketException(Exception):
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(code={self.code!r}, reason={self.reason!r})"
|
||||
|
||||
|
||||
__deprecated__ = "ExceptionMiddleware"
|
||||
|
||||
|
||||
def __getattr__(name: str) -> typing.Any: # pragma: no cover
|
||||
if name == __deprecated__:
|
||||
from starlette.middleware.exceptions import ExceptionMiddleware
|
||||
|
||||
warnings.warn(
|
||||
f"{__deprecated__} is deprecated on `starlette.exceptions`. "
|
||||
f"Import it from `starlette.middleware.exceptions` instead.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
return ExceptionMiddleware
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(list(__all__) + [__deprecated__]) # pragma: no cover
|
||||
|
@@ -8,12 +8,20 @@ from urllib.parse import unquote_plus
|
||||
|
||||
from starlette.datastructures import FormData, Headers, UploadFile
|
||||
|
||||
try:
|
||||
import multipart
|
||||
from multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: nocover
|
||||
parse_options_header = None
|
||||
multipart = None
|
||||
if typing.TYPE_CHECKING:
|
||||
import python_multipart as multipart
|
||||
from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
|
||||
else:
|
||||
try:
|
||||
try:
|
||||
import python_multipart as multipart
|
||||
from python_multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
import multipart
|
||||
from multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
multipart = None
|
||||
parse_options_header = None
|
||||
|
||||
|
||||
class FormMessage(Enum):
|
||||
@@ -28,12 +36,12 @@ class FormMessage(Enum):
|
||||
class MultipartPart:
|
||||
content_disposition: bytes | None = None
|
||||
field_name: str = ""
|
||||
data: bytes = b""
|
||||
data: bytearray = field(default_factory=bytearray)
|
||||
file: UploadFile | None = None
|
||||
item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
|
||||
|
||||
|
||||
def _user_safe_decode(src: bytes, codec: str) -> str:
|
||||
def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
|
||||
try:
|
||||
return src.decode(codec)
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
@@ -46,12 +54,8 @@ class MultiPartException(Exception):
|
||||
|
||||
|
||||
class FormParser:
|
||||
def __init__(
|
||||
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
|
||||
) -> None:
|
||||
assert (
|
||||
multipart is not None
|
||||
), "The `python-multipart` library must be installed to use form parsing."
|
||||
def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
|
||||
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
|
||||
self.headers = headers
|
||||
self.stream = stream
|
||||
self.messages: list[tuple[FormMessage, bytes]] = []
|
||||
@@ -78,7 +82,7 @@ class FormParser:
|
||||
|
||||
async def parse(self) -> FormData:
|
||||
# Callbacks dictionary.
|
||||
callbacks = {
|
||||
callbacks: QuerystringCallbacks = {
|
||||
"on_field_start": self.on_field_start,
|
||||
"on_field_name": self.on_field_name,
|
||||
"on_field_data": self.on_field_data,
|
||||
@@ -91,7 +95,7 @@ class FormParser:
|
||||
field_name = b""
|
||||
field_value = b""
|
||||
|
||||
items: list[tuple[str, typing.Union[str, UploadFile]]] = []
|
||||
items: list[tuple[str, str | UploadFile]] = []
|
||||
|
||||
# Feed the parser with data from the request.
|
||||
async for chunk in self.stream:
|
||||
@@ -118,7 +122,7 @@ class FormParser:
|
||||
|
||||
|
||||
class MultiPartParser:
|
||||
max_file_size = 1024 * 1024
|
||||
max_file_size = 1024 * 1024 # 1MB
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -127,10 +131,9 @@ class MultiPartParser:
|
||||
*,
|
||||
max_files: int | float = 1000,
|
||||
max_fields: int | float = 1000,
|
||||
max_part_size: int = 1024 * 1024, # 1MB
|
||||
) -> None:
|
||||
assert (
|
||||
multipart is not None
|
||||
), "The `python-multipart` library must be installed to use form parsing."
|
||||
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
|
||||
self.headers = headers
|
||||
self.stream = stream
|
||||
self.max_files = max_files
|
||||
@@ -145,6 +148,7 @@ class MultiPartParser:
|
||||
self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
|
||||
self._file_parts_to_finish: list[MultipartPart] = []
|
||||
self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
|
||||
self.max_part_size = max_part_size
|
||||
|
||||
def on_part_begin(self) -> None:
|
||||
self._current_part = MultipartPart()
|
||||
@@ -152,7 +156,9 @@ class MultiPartParser:
|
||||
def on_part_data(self, data: bytes, start: int, end: int) -> None:
|
||||
message_bytes = data[start:end]
|
||||
if self._current_part.file is None:
|
||||
self._current_part.data += message_bytes
|
||||
if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
|
||||
raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
|
||||
self._current_part.data.extend(message_bytes)
|
||||
else:
|
||||
self._file_parts_to_write.append((self._current_part, message_bytes))
|
||||
|
||||
@@ -181,30 +187,20 @@ class MultiPartParser:
|
||||
field = self._current_partial_header_name.lower()
|
||||
if field == b"content-disposition":
|
||||
self._current_part.content_disposition = self._current_partial_header_value
|
||||
self._current_part.item_headers.append(
|
||||
(field, self._current_partial_header_value)
|
||||
)
|
||||
self._current_part.item_headers.append((field, self._current_partial_header_value))
|
||||
self._current_partial_header_name = b""
|
||||
self._current_partial_header_value = b""
|
||||
|
||||
def on_headers_finished(self) -> None:
|
||||
disposition, options = parse_options_header(
|
||||
self._current_part.content_disposition
|
||||
)
|
||||
disposition, options = parse_options_header(self._current_part.content_disposition)
|
||||
try:
|
||||
self._current_part.field_name = _user_safe_decode(
|
||||
options[b"name"], self._charset
|
||||
)
|
||||
self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
|
||||
except KeyError:
|
||||
raise MultiPartException(
|
||||
'The Content-Disposition header field "name" must be ' "provided."
|
||||
)
|
||||
raise MultiPartException('The Content-Disposition header field "name" must be provided.')
|
||||
if b"filename" in options:
|
||||
self._current_files += 1
|
||||
if self._current_files > self.max_files:
|
||||
raise MultiPartException(
|
||||
f"Too many files. Maximum number of files is {self.max_files}."
|
||||
)
|
||||
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
|
||||
filename = _user_safe_decode(options[b"filename"], self._charset)
|
||||
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
|
||||
self._files_to_close_on_error.append(tempfile)
|
||||
@@ -217,9 +213,7 @@ class MultiPartParser:
|
||||
else:
|
||||
self._current_fields += 1
|
||||
if self._current_fields > self.max_fields:
|
||||
raise MultiPartException(
|
||||
f"Too many fields. Maximum number of fields is {self.max_fields}."
|
||||
)
|
||||
raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
|
||||
self._current_part.file = None
|
||||
|
||||
def on_end(self) -> None:
|
||||
@@ -238,7 +232,7 @@ class MultiPartParser:
|
||||
raise MultiPartException("Missing boundary in multipart.")
|
||||
|
||||
# Callbacks dictionary.
|
||||
callbacks = {
|
||||
callbacks: MultipartCallbacks = {
|
||||
"on_part_begin": self.on_part_begin,
|
||||
"on_part_data": self.on_part_data,
|
||||
"on_part_end": self.on_part_end,
|
||||
|
@@ -1,30 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any, Iterator, Protocol
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, Protocol
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class _MiddlewareClass(Protocol[P]):
|
||||
def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
... # pragma: no cover
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
... # pragma: no cover
|
||||
class _MiddlewareFactory(Protocol[P]):
|
||||
def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover
|
||||
|
||||
|
||||
class Middleware:
|
||||
def __init__(
|
||||
self,
|
||||
cls: type[_MiddlewareClass[P]],
|
||||
cls: _MiddlewareFactory[P],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
@@ -40,5 +37,6 @@ class Middleware:
|
||||
class_name = self.__class__.__name__
|
||||
args_strings = [f"{value!r}" for value in self.args]
|
||||
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
|
||||
args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
|
||||
name = getattr(self.cls, "__name__", "")
|
||||
args_repr = ", ".join([name] + args_strings + option_strings)
|
||||
return f"{class_name}({args_repr})"
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette.authentication import (
|
||||
@@ -16,15 +18,13 @@ class AuthenticationMiddleware:
|
||||
self,
|
||||
app: ASGIApp,
|
||||
backend: AuthenticationBackend,
|
||||
on_error: typing.Optional[
|
||||
typing.Callable[[HTTPConnection, AuthenticationError], Response]
|
||||
] = None,
|
||||
on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.on_error: typing.Callable[
|
||||
[HTTPConnection, AuthenticationError], Response
|
||||
] = on_error if on_error is not None else self.default_on_error
|
||||
self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
|
||||
on_error if on_error is not None else self.default_on_error
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ["http", "websocket"]:
|
||||
|
@@ -1,18 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import anyio
|
||||
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
||||
|
||||
from starlette._utils import collapse_excgroups
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.requests import ClientDisconnect, Request
|
||||
from starlette.responses import ContentStream, Response, StreamingResponse
|
||||
from starlette.responses import AsyncContentStream, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
|
||||
DispatchFunction = typing.Callable[
|
||||
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
|
||||
]
|
||||
DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
@@ -54,6 +52,7 @@ class _CachedRequest(Request):
|
||||
# at this point a disconnect is all that we should be receiving
|
||||
# if we get something else, things went wrong somewhere
|
||||
raise RuntimeError(f"Unexpected message received: {msg['type']}")
|
||||
self._wrapped_rcv_disconnected = True
|
||||
return msg
|
||||
|
||||
# wrapped_rcv state 3: not yet consumed
|
||||
@@ -92,9 +91,7 @@ class _CachedRequest(Request):
|
||||
|
||||
|
||||
class BaseHTTPMiddleware:
|
||||
def __init__(
|
||||
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
|
||||
) -> None:
|
||||
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
|
||||
self.app = app
|
||||
self.dispatch_func = self.dispatch if dispatch is None else dispatch
|
||||
|
||||
@@ -108,10 +105,7 @@ class BaseHTTPMiddleware:
|
||||
response_sent = anyio.Event()
|
||||
|
||||
async def call_next(request: Request) -> Response:
|
||||
app_exc: typing.Optional[Exception] = None
|
||||
send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
|
||||
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
|
||||
send_stream, recv_stream = anyio.create_memory_object_stream()
|
||||
app_exc: Exception | None = None
|
||||
|
||||
async def receive_or_disconnect() -> Message:
|
||||
if response_sent.is_set():
|
||||
@@ -132,10 +126,6 @@ class BaseHTTPMiddleware:
|
||||
|
||||
return message
|
||||
|
||||
async def close_recv_stream_on_response_sent() -> None:
|
||||
await response_sent.wait()
|
||||
recv_stream.close()
|
||||
|
||||
async def send_no_error(message: Message) -> None:
|
||||
try:
|
||||
await send_stream.send(message)
|
||||
@@ -146,13 +136,12 @@ class BaseHTTPMiddleware:
|
||||
async def coro() -> None:
|
||||
nonlocal app_exc
|
||||
|
||||
async with send_stream:
|
||||
with send_stream:
|
||||
try:
|
||||
await self.app(scope, receive_or_disconnect, send_no_error)
|
||||
except Exception as exc:
|
||||
app_exc = exc
|
||||
|
||||
task_group.start_soon(close_recv_stream_on_response_sent)
|
||||
task_group.start_soon(coro)
|
||||
|
||||
try:
|
||||
@@ -168,50 +157,65 @@ class BaseHTTPMiddleware:
|
||||
assert message["type"] == "http.response.start"
|
||||
|
||||
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
|
||||
async with recv_stream:
|
||||
async for message in recv_stream:
|
||||
assert message["type"] == "http.response.body"
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
async for message in recv_stream:
|
||||
assert message["type"] == "http.response.body"
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
yield body
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
|
||||
if app_exc is not None:
|
||||
raise app_exc
|
||||
|
||||
response = _StreamingResponse(
|
||||
status_code=message["status"], content=body_stream(), info=info
|
||||
)
|
||||
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
|
||||
response.raw_headers = message["headers"]
|
||||
return response
|
||||
|
||||
with collapse_excgroups():
|
||||
streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
|
||||
send_stream, recv_stream = streams
|
||||
with recv_stream, send_stream, collapse_excgroups():
|
||||
async with anyio.create_task_group() as task_group:
|
||||
response = await self.dispatch_func(request, call_next)
|
||||
await response(scope, wrapped_receive, send)
|
||||
response_sent.set()
|
||||
recv_stream.close()
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class _StreamingResponse(StreamingResponse):
|
||||
class _StreamingResponse(Response):
|
||||
def __init__(
|
||||
self,
|
||||
content: ContentStream,
|
||||
content: AsyncContentStream,
|
||||
status_code: int = 200,
|
||||
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
media_type: typing.Optional[str] = None,
|
||||
background: typing.Optional[BackgroundTask] = None,
|
||||
info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
info: typing.Mapping[str, typing.Any] | None = None,
|
||||
) -> None:
|
||||
self._info = info
|
||||
super().__init__(content, status_code, headers, media_type, background)
|
||||
self.info = info
|
||||
self.body_iterator = content
|
||||
self.status_code = status_code
|
||||
self.media_type = media_type
|
||||
self.init_headers(headers)
|
||||
self.background = None
|
||||
|
||||
async def stream_response(self, send: Send) -> None:
|
||||
if self._info:
|
||||
await send({"type": "http.response.debug", "info": self._info})
|
||||
return await super().stream_response(send)
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.info is not None:
|
||||
await send({"type": "http.response.debug", "info": self.info})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
|
||||
async for chunk in self.body_iterator:
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
if self.background:
|
||||
await self.background()
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import re
|
||||
import typing
|
||||
@@ -18,7 +20,7 @@ class CORSMiddleware:
|
||||
allow_methods: typing.Sequence[str] = ("GET",),
|
||||
allow_headers: typing.Sequence[str] = (),
|
||||
allow_credentials: bool = False,
|
||||
allow_origin_regex: typing.Optional[str] = None,
|
||||
allow_origin_regex: str | None = None,
|
||||
expose_headers: typing.Sequence[str] = (),
|
||||
max_age: int = 600,
|
||||
) -> None:
|
||||
@@ -94,9 +96,7 @@ class CORSMiddleware:
|
||||
if self.allow_all_origins:
|
||||
return True
|
||||
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
|
||||
origin
|
||||
):
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
|
||||
return True
|
||||
|
||||
return origin in self.allow_origins
|
||||
@@ -139,15 +139,11 @@ class CORSMiddleware:
|
||||
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(
|
||||
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
|
||||
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def send(
|
||||
self, message: Message, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
@@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
@@ -137,9 +140,7 @@ class ServerErrorMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handler: typing.Optional[
|
||||
typing.Callable[[Request, Exception], typing.Any]
|
||||
] = None,
|
||||
handler: typing.Callable[[Request, Exception], typing.Any] | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
@@ -185,9 +186,7 @@ class ServerErrorMiddleware:
|
||||
# to optionally raise the error within the test case.
|
||||
raise exc
|
||||
|
||||
def format_line(
|
||||
self, index: int, line: str, frame_lineno: int, frame_index: int
|
||||
) -> str:
|
||||
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
|
||||
values = {
|
||||
# HTML escape - line could contain < or >
|
||||
"line": html.escape(line).replace(" ", " "),
|
||||
@@ -224,9 +223,7 @@ class ServerErrorMiddleware:
|
||||
return FRAME_TEMPLATE.format(**values)
|
||||
|
||||
def generate_html(self, exc: Exception, limit: int = 7) -> str:
|
||||
traceback_obj = traceback.TracebackException.from_exception(
|
||||
exc, capture_locals=True
|
||||
)
|
||||
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
|
||||
|
||||
exc_html = ""
|
||||
is_collapsed = False
|
||||
@@ -237,11 +234,13 @@ class ServerErrorMiddleware:
|
||||
exc_html += self.generate_frame_html(frame, is_collapsed)
|
||||
is_collapsed = True
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type_str
|
||||
else: # pragma: no cover
|
||||
exc_type_str = traceback_obj.exc_type.__name__
|
||||
|
||||
# escape error class and text
|
||||
error = (
|
||||
f"{html.escape(traceback_obj.exc_type.__name__)}: "
|
||||
f"{html.escape(str(traceback_obj))}"
|
||||
)
|
||||
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
|
||||
|
||||
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)
|
||||
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette._exception_handler import (
|
||||
@@ -16,9 +18,7 @@ class ExceptionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handlers: typing.Optional[
|
||||
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
|
||||
] = None,
|
||||
handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
@@ -28,13 +28,13 @@ class ExceptionMiddleware:
|
||||
HTTPException: self.http_exception,
|
||||
WebSocketException: self.websocket_exception,
|
||||
}
|
||||
if handlers is not None:
|
||||
if handlers is not None: # pragma: no branch
|
||||
for key, value in handlers.items():
|
||||
self.add_exception_handler(key, value)
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: typing.Callable[[Request, Exception], Response],
|
||||
) -> None:
|
||||
if isinstance(exc_class_or_status_code, int):
|
||||
@@ -53,7 +53,7 @@ class ExceptionMiddleware:
|
||||
self._status_handlers,
|
||||
)
|
||||
|
||||
conn: typing.Union[Request, WebSocket]
|
||||
conn: Request | WebSocket
|
||||
if scope["type"] == "http":
|
||||
conn = Request(scope, receive, send)
|
||||
else:
|
||||
@@ -65,9 +65,7 @@ class ExceptionMiddleware:
|
||||
assert isinstance(exc, HTTPException)
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(
|
||||
exc.detail, status_code=exc.status_code, headers=exc.headers
|
||||
)
|
||||
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
|
||||
|
||||
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
|
||||
assert isinstance(exc, WebSocketException)
|
||||
|
@@ -7,20 +7,16 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class GZipMiddleware:
|
||||
def __init__(
|
||||
self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
|
||||
) -> None:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
self.compresslevel = compresslevel
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "http":
|
||||
if scope["type"] == "http": # pragma: no branch
|
||||
headers = Headers(scope=scope)
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
responder = GZipResponder(
|
||||
self.app, self.minimum_size, compresslevel=self.compresslevel
|
||||
)
|
||||
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
|
||||
await responder(scope, receive, send)
|
||||
return
|
||||
await self.app(scope, receive, send)
|
||||
@@ -35,13 +31,12 @@ class GZipResponder:
|
||||
self.started = False
|
||||
self.content_encoding_set = False
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(
|
||||
mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
|
||||
)
|
||||
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
self.send = send
|
||||
await self.app(scope, receive, self.send_with_gzip)
|
||||
with self.gzip_buffer, self.gzip_file:
|
||||
await self.app(scope, receive, self.send_with_gzip)
|
||||
|
||||
async def send_with_gzip(self, message: Message) -> None:
|
||||
message_type = message["type"]
|
||||
@@ -93,7 +88,7 @@ class GZipResponder:
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
elif message_type == "http.response.body":
|
||||
elif message_type == "http.response.body": # pragma: no branch
|
||||
# Remaining body in streaming GZip response.
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from base64 import b64decode, b64encode
|
||||
@@ -14,13 +16,13 @@ class SessionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
secret_key: typing.Union[str, Secret],
|
||||
secret_key: str | Secret,
|
||||
session_cookie: str = "session",
|
||||
max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
|
||||
path: str = "/",
|
||||
same_site: typing.Literal["lax", "strict", "none"] = "lax",
|
||||
https_only: bool = False,
|
||||
domain: typing.Optional[str] = None,
|
||||
domain: str | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.signer = itsdangerous.TimestampSigner(str(secret_key))
|
||||
@@ -59,7 +61,7 @@ class SessionMiddleware:
|
||||
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
|
||||
data = self.signer.sign(data)
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
|
||||
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data=data.decode("utf-8"),
|
||||
path=self.path,
|
||||
@@ -70,7 +72,7 @@ class SessionMiddleware:
|
||||
elif not initial_session_was_empty:
|
||||
# The session has been cleared.
|
||||
headers = MutableHeaders(scope=message)
|
||||
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
|
||||
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
|
||||
session_cookie=self.session_cookie,
|
||||
data="null",
|
||||
path=self.path,
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import URL, Headers
|
||||
@@ -11,7 +13,7 @@ class TrustedHostMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allowed_hosts: typing.Optional[typing.Sequence[str]] = None,
|
||||
allowed_hosts: typing.Sequence[str] | None = None,
|
||||
www_redirect: bool = True,
|
||||
) -> None:
|
||||
if allowed_hosts is None:
|
||||
@@ -39,9 +41,7 @@ class TrustedHostMiddleware:
|
||||
is_valid_host = False
|
||||
found_www_redirect = False
|
||||
for pattern in self.allowed_hosts:
|
||||
if host == pattern or (
|
||||
pattern.startswith("*") and host.endswith(pattern[1:])
|
||||
):
|
||||
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
|
||||
is_valid_host = True
|
||||
break
|
||||
elif "www." + host == pattern:
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import math
|
||||
import sys
|
||||
@@ -16,7 +18,7 @@ warnings.warn(
|
||||
)
|
||||
|
||||
|
||||
def build_environ(scope: Scope, body: bytes) -> typing.Dict[str, typing.Any]:
|
||||
def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]:
|
||||
"""
|
||||
Builds a scope and request body into a WSGI environ object.
|
||||
"""
|
||||
@@ -87,9 +89,7 @@ class WSGIResponder:
|
||||
self.scope = scope
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
|
||||
math.inf
|
||||
)
|
||||
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
|
||||
self.response_started = False
|
||||
self.exc_info: typing.Any = None
|
||||
|
||||
@@ -117,11 +117,11 @@ class WSGIResponder:
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: typing.List[typing.Tuple[str, str]],
|
||||
response_headers: list[tuple[str, str]],
|
||||
exc_info: typing.Any = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started:
|
||||
if not self.response_started: # pragma: no branch
|
||||
self.response_started = True
|
||||
status_code_string, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_string)
|
||||
@@ -140,7 +140,7 @@ class WSGIResponder:
|
||||
|
||||
def wsgi(
|
||||
self,
|
||||
environ: typing.Dict[str, typing.Any],
|
||||
environ: dict[str, typing.Any],
|
||||
start_response: typing.Callable[..., typing.Any],
|
||||
) -> None:
|
||||
for chunk in self.app(environ, start_response):
|
||||
@@ -149,6 +149,4 @@ class WSGIResponder:
|
||||
{"type": "http.response.body", "body": chunk, "more_body": True},
|
||||
)
|
||||
|
||||
anyio.from_thread.run(
|
||||
self.stream_send.send, {"type": "http.response.body", "body": b""}
|
||||
)
|
||||
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
|
||||
|
@@ -12,14 +12,19 @@ from starlette.exceptions import HTTPException
|
||||
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
try:
|
||||
from multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: nocover
|
||||
parse_options_header = None
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from python_multipart.multipart import parse_options_header
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Router
|
||||
else:
|
||||
try:
|
||||
try:
|
||||
from python_multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
from multipart.multipart import parse_options_header
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
parse_options_header = None
|
||||
|
||||
|
||||
SERVER_PUSH_HEADERS_TO_COPY = {
|
||||
@@ -43,7 +48,7 @@ def cookie_parser(cookie_string: str) -> dict[str, str]:
|
||||
Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
|
||||
on an outdated spec and will fail on lots of input we want to support
|
||||
"""
|
||||
cookie_dict: typing.Dict[str, str] = {}
|
||||
cookie_dict: dict[str, str] = {}
|
||||
for chunk in cookie_string.split(";"):
|
||||
if "=" in chunk:
|
||||
key, val = chunk.split("=", 1)
|
||||
@@ -93,7 +98,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
|
||||
@property
|
||||
def url(self) -> URL:
|
||||
if not hasattr(self, "_url"):
|
||||
if not hasattr(self, "_url"): # pragma: no branch
|
||||
self._url = URL(scope=self.scope)
|
||||
return self._url
|
||||
|
||||
@@ -104,9 +109,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
# This is used by request.url_for, it might be used inside a Mount which
|
||||
# would have its own child scope with its own root_path, but the base URL
|
||||
# for url_for should still be the top level app root path.
|
||||
app_root_path = base_url_scope.get(
|
||||
"app_root_path", base_url_scope.get("root_path", "")
|
||||
)
|
||||
app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
|
||||
path = app_root_path
|
||||
if not path.endswith("/"):
|
||||
path += "/"
|
||||
@@ -124,7 +127,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
|
||||
@property
|
||||
def query_params(self) -> QueryParams:
|
||||
if not hasattr(self, "_query_params"):
|
||||
if not hasattr(self, "_query_params"): # pragma: no branch
|
||||
self._query_params = QueryParams(self.scope["query_string"])
|
||||
return self._query_params
|
||||
|
||||
@@ -135,7 +138,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
@property
|
||||
def cookies(self) -> dict[str, str]:
|
||||
if not hasattr(self, "_cookies"):
|
||||
cookies: typing.Dict[str, str] = {}
|
||||
cookies: dict[str, str] = {}
|
||||
cookie_header = self.headers.get("cookie")
|
||||
|
||||
if cookie_header:
|
||||
@@ -145,7 +148,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
|
||||
@property
|
||||
def client(self) -> Address | None:
|
||||
# client is a 2 item tuple of (host, port), None or missing
|
||||
# client is a 2 item tuple of (host, port), None if missing
|
||||
host_port = self.scope.get("client")
|
||||
if host_port is not None:
|
||||
return Address(*host_port)
|
||||
@@ -153,23 +156,17 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
|
||||
@property
|
||||
def session(self) -> dict[str, typing.Any]:
|
||||
assert (
|
||||
"session" in self.scope
|
||||
), "SessionMiddleware must be installed to access request.session"
|
||||
assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
|
||||
return self.scope["session"] # type: ignore[no-any-return]
|
||||
|
||||
@property
|
||||
def auth(self) -> typing.Any:
|
||||
assert (
|
||||
"auth" in self.scope
|
||||
), "AuthenticationMiddleware must be installed to access request.auth"
|
||||
assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
|
||||
return self.scope["auth"]
|
||||
|
||||
@property
|
||||
def user(self) -> typing.Any:
|
||||
assert (
|
||||
"user" in self.scope
|
||||
), "AuthenticationMiddleware must be installed to access request.user"
|
||||
assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
|
||||
return self.scope["user"]
|
||||
|
||||
@property
|
||||
@@ -183,8 +180,10 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
|
||||
return self._state
|
||||
|
||||
def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
|
||||
router: Router = self.scope["router"]
|
||||
url_path = router.url_path_for(name, **path_params)
|
||||
url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
|
||||
if url_path_provider is None:
|
||||
raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
|
||||
url_path = url_path_provider.url_path_for(name, **path_params)
|
||||
return url_path.make_absolute_url(base_url=self.base_url)
|
||||
|
||||
|
||||
@@ -197,11 +196,9 @@ async def empty_send(message: Message) -> typing.NoReturn:
|
||||
|
||||
|
||||
class Request(HTTPConnection):
|
||||
_form: typing.Optional[FormData]
|
||||
_form: FormData | None
|
||||
|
||||
def __init__(
|
||||
self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
|
||||
):
|
||||
def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
|
||||
super().__init__(scope)
|
||||
assert scope["type"] == "http"
|
||||
self._receive = receive
|
||||
@@ -233,29 +230,33 @@ class Request(HTTPConnection):
|
||||
self._stream_consumed = True
|
||||
if body:
|
||||
yield body
|
||||
elif message["type"] == "http.disconnect":
|
||||
elif message["type"] == "http.disconnect": # pragma: no branch
|
||||
self._is_disconnected = True
|
||||
raise ClientDisconnect()
|
||||
yield b""
|
||||
|
||||
async def body(self) -> bytes:
|
||||
if not hasattr(self, "_body"):
|
||||
chunks: "typing.List[bytes]" = []
|
||||
chunks: list[bytes] = []
|
||||
async for chunk in self.stream():
|
||||
chunks.append(chunk)
|
||||
self._body = b"".join(chunks)
|
||||
return self._body
|
||||
|
||||
async def json(self) -> typing.Any:
|
||||
if not hasattr(self, "_json"):
|
||||
if not hasattr(self, "_json"): # pragma: no branch
|
||||
body = await self.body()
|
||||
self._json = json.loads(body)
|
||||
return self._json
|
||||
|
||||
async def _get_form(
|
||||
self, *, max_files: int | float = 1000, max_fields: int | float = 1000
|
||||
self,
|
||||
*,
|
||||
max_files: int | float = 1000,
|
||||
max_fields: int | float = 1000,
|
||||
max_part_size: int = 1024 * 1024,
|
||||
) -> FormData:
|
||||
if self._form is None:
|
||||
if self._form is None: # pragma: no branch
|
||||
assert (
|
||||
parse_options_header is not None
|
||||
), "The `python-multipart` library must be installed to use form parsing."
|
||||
@@ -269,6 +270,7 @@ class Request(HTTPConnection):
|
||||
self.stream(),
|
||||
max_files=max_files,
|
||||
max_fields=max_fields,
|
||||
max_part_size=max_part_size,
|
||||
)
|
||||
self._form = await multipart_parser.parse()
|
||||
except MultiPartException as exc:
|
||||
@@ -283,14 +285,18 @@ class Request(HTTPConnection):
|
||||
return self._form
|
||||
|
||||
def form(
|
||||
self, *, max_files: int | float = 1000, max_fields: int | float = 1000
|
||||
self,
|
||||
*,
|
||||
max_files: int | float = 1000,
|
||||
max_fields: int | float = 1000,
|
||||
max_part_size: int = 1024 * 1024,
|
||||
) -> AwaitableOrContextManager[FormData]:
|
||||
return AwaitableOrContextManagerWrapper(
|
||||
self._get_form(max_files=max_files, max_fields=max_fields)
|
||||
self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._form is not None:
|
||||
if self._form is not None: # pragma: no branch
|
||||
await self._form.close()
|
||||
|
||||
async def is_disconnected(self) -> bool:
|
||||
@@ -309,12 +315,8 @@ class Request(HTTPConnection):
|
||||
|
||||
async def send_push_promise(self, path: str) -> None:
|
||||
if "http.response.push" in self.scope.get("extensions", {}):
|
||||
raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = []
|
||||
raw_headers: list[tuple[bytes, bytes]] = []
|
||||
for name in SERVER_PUSH_HEADERS_TO_COPY:
|
||||
for value in self.headers.getlist(name):
|
||||
raw_headers.append(
|
||||
(name.encode("latin-1"), value.encode("latin-1"))
|
||||
)
|
||||
await self._send(
|
||||
{"type": "http.response.push", "path": path, "headers": raw_headers}
|
||||
)
|
||||
raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
|
||||
await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})
|
||||
|
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import http.cookies
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import stat
|
||||
import typing
|
||||
import warnings
|
||||
@@ -10,15 +12,16 @@ from datetime import datetime
|
||||
from email.utils import format_datetime, formatdate
|
||||
from functools import partial
|
||||
from mimetypes import guess_type
|
||||
from secrets import token_hex
|
||||
from urllib.parse import quote
|
||||
|
||||
import anyio
|
||||
import anyio.to_thread
|
||||
|
||||
from starlette._compat import md5_hexdigest
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import URL, MutableHeaders
|
||||
from starlette.datastructures import URL, Headers, MutableHeaders
|
||||
from starlette.requests import ClientDisconnect
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
|
||||
@@ -41,10 +44,10 @@ class Response:
|
||||
self.body = self.render(content)
|
||||
self.init_headers(headers)
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
def render(self, content: typing.Any) -> bytes | memoryview:
|
||||
if content is None:
|
||||
return b""
|
||||
if isinstance(content, bytes):
|
||||
if isinstance(content, (bytes, memoryview)):
|
||||
return content
|
||||
return content.encode(self.charset) # type: ignore
|
||||
|
||||
@@ -54,10 +57,7 @@ class Response:
|
||||
populate_content_length = True
|
||||
populate_content_type = True
|
||||
else:
|
||||
raw_headers = [
|
||||
(k.lower().encode("latin-1"), v.encode("latin-1"))
|
||||
for k, v in headers.items()
|
||||
]
|
||||
raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
|
||||
keys = [h[0] for h in raw_headers]
|
||||
populate_content_length = b"content-length" not in keys
|
||||
populate_content_type = b"content-type" not in keys
|
||||
@@ -73,10 +73,7 @@ class Response:
|
||||
|
||||
content_type = self.media_type
|
||||
if content_type is not None and populate_content_type:
|
||||
if (
|
||||
content_type.startswith("text/")
|
||||
and "charset=" not in content_type.lower()
|
||||
):
|
||||
if content_type.startswith("text/") and "charset=" not in content_type.lower():
|
||||
content_type += "; charset=" + self.charset
|
||||
raw_headers.append((b"content-type", content_type.encode("latin-1")))
|
||||
|
||||
@@ -94,7 +91,7 @@ class Response:
|
||||
value: str = "",
|
||||
max_age: int | None = None,
|
||||
expires: datetime | str | int | None = None,
|
||||
path: str = "/",
|
||||
path: str | None = "/",
|
||||
domain: str | None = None,
|
||||
secure: bool = False,
|
||||
httponly: bool = False,
|
||||
@@ -148,14 +145,15 @@ class Response:
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
prefix = "websocket." if scope["type"] == "websocket" else ""
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"type": prefix + "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": self.body})
|
||||
await send({"type": prefix + "http.response.body", "body": self.body})
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
@@ -200,13 +198,11 @@ class RedirectResponse(Response):
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
content=b"", status_code=status_code, headers=headers, background=background
|
||||
)
|
||||
super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
|
||||
self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
|
||||
|
||||
|
||||
Content = typing.Union[str, bytes]
|
||||
Content = typing.Union[str, bytes, memoryview]
|
||||
SyncContentStream = typing.Iterable[Content]
|
||||
AsyncContentStream = typing.AsyncIterable[Content]
|
||||
ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
|
||||
@@ -247,26 +243,47 @@ class StreamingResponse(Response):
|
||||
}
|
||||
)
|
||||
async for chunk in self.body_iterator:
|
||||
if not isinstance(chunk, bytes):
|
||||
if not isinstance(chunk, (bytes, memoryview)):
|
||||
chunk = chunk.encode(self.charset)
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
async with anyio.create_task_group() as task_group:
|
||||
spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))
|
||||
|
||||
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
|
||||
await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
if spec_version >= (2, 4):
|
||||
try:
|
||||
await self.stream_response(send)
|
||||
except OSError:
|
||||
raise ClientDisconnect()
|
||||
else:
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
task_group.start_soon(wrap, partial(self.stream_response, send))
|
||||
await wrap(partial(self.listen_for_disconnect, receive))
|
||||
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
|
||||
await func()
|
||||
task_group.cancel_scope.cancel()
|
||||
|
||||
task_group.start_soon(wrap, partial(self.stream_response, send))
|
||||
await wrap(partial(self.listen_for_disconnect, receive))
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
|
||||
|
||||
class MalformedRangeHeader(Exception):
|
||||
def __init__(self, content: str = "Malformed range header.") -> None:
|
||||
self.content = content
|
||||
|
||||
|
||||
class RangeNotSatisfiable(Exception):
|
||||
def __init__(self, max_size: int) -> None:
|
||||
self.max_size = max_size
|
||||
|
||||
|
||||
_RANGE_PATTERN = re.compile(r"(\d*)-(\d*)")
|
||||
|
||||
|
||||
class FileResponse(Response):
|
||||
chunk_size = 64 * 1024
|
||||
|
||||
@@ -295,16 +312,13 @@ class FileResponse(Response):
|
||||
self.media_type = media_type
|
||||
self.background = background
|
||||
self.init_headers(headers)
|
||||
self.headers.setdefault("accept-ranges", "bytes")
|
||||
if self.filename is not None:
|
||||
content_disposition_filename = quote(self.filename)
|
||||
if content_disposition_filename != self.filename:
|
||||
content_disposition = "{}; filename*=utf-8''{}".format(
|
||||
content_disposition_type, content_disposition_filename
|
||||
)
|
||||
content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
|
||||
else:
|
||||
content_disposition = '{}; filename="{}"'.format(
|
||||
content_disposition_type, self.filename
|
||||
)
|
||||
content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
|
||||
self.headers.setdefault("content-disposition", content_disposition)
|
||||
self.stat_result = stat_result
|
||||
if stat_result is not None:
|
||||
@@ -314,13 +328,14 @@ class FileResponse(Response):
|
||||
content_length = str(stat_result.st_size)
|
||||
last_modified = formatdate(stat_result.st_mtime, usegmt=True)
|
||||
etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
|
||||
etag = f'"{md5_hexdigest(etag_base.encode(), usedforsecurity=False)}"'
|
||||
etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'
|
||||
|
||||
self.headers.setdefault("content-length", content_length)
|
||||
self.headers.setdefault("last-modified", last_modified)
|
||||
self.headers.setdefault("etag", etag)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
send_header_only: bool = scope["method"].upper() == "HEAD"
|
||||
if self.stat_result is None:
|
||||
try:
|
||||
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
|
||||
@@ -331,29 +346,192 @@ class FileResponse(Response):
|
||||
mode = stat_result.st_mode
|
||||
if not stat.S_ISREG(mode):
|
||||
raise RuntimeError(f"File at path {self.path} is not a file.")
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
if scope["method"].upper() == "HEAD":
|
||||
else:
|
||||
stat_result = self.stat_result
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
http_range = headers.get("range")
|
||||
http_if_range = headers.get("if-range")
|
||||
|
||||
if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
|
||||
await self._handle_simple(send, send_header_only)
|
||||
else:
|
||||
try:
|
||||
ranges = self._parse_range_header(http_range, stat_result.st_size)
|
||||
except MalformedRangeHeader as exc:
|
||||
return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
|
||||
except RangeNotSatisfiable as exc:
|
||||
response = PlainTextResponse(status_code=416, headers={"Content-Range": f"*/{exc.max_size}"})
|
||||
return await response(scope, receive, send)
|
||||
|
||||
if len(ranges) == 1:
|
||||
start, end = ranges[0]
|
||||
await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
|
||||
else:
|
||||
await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)
|
||||
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
|
||||
async def _handle_simple(self, send: Send, send_header_only: bool) -> None:
|
||||
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
|
||||
if send_header_only:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
elif "extensions" in scope and "http.response.pathsend" in scope["extensions"]:
|
||||
await send({"type": "http.response.pathsend", "path": str(self.path)})
|
||||
else:
|
||||
async with await anyio.open_file(self.path, mode="rb") as file:
|
||||
more_body = True
|
||||
while more_body:
|
||||
chunk = await file.read(self.chunk_size)
|
||||
more_body = len(chunk) == self.chunk_size
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": chunk,
|
||||
"more_body": more_body,
|
||||
}
|
||||
)
|
||||
if self.background is not None:
|
||||
await self.background()
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
|
||||
|
||||
async def _handle_single_range(
|
||||
self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
|
||||
) -> None:
|
||||
self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
|
||||
self.headers["content-length"] = str(end - start)
|
||||
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
|
||||
if send_header_only:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
else:
|
||||
async with await anyio.open_file(self.path, mode="rb") as file:
|
||||
await file.seek(start)
|
||||
more_body = True
|
||||
while more_body:
|
||||
chunk = await file.read(min(self.chunk_size, end - start))
|
||||
start += len(chunk)
|
||||
more_body = len(chunk) == self.chunk_size and start < end
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
|
||||
|
||||
async def _handle_multiple_ranges(
|
||||
self,
|
||||
send: Send,
|
||||
ranges: list[tuple[int, int]],
|
||||
file_size: int,
|
||||
send_header_only: bool,
|
||||
) -> None:
|
||||
# In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes).
|
||||
boundary = token_hex(13)
|
||||
content_length, header_generator = self.generate_multipart(
|
||||
ranges, boundary, file_size, self.headers["content-type"]
|
||||
)
|
||||
self.headers["content-range"] = f"multipart/byteranges; boundary={boundary}"
|
||||
self.headers["content-length"] = str(content_length)
|
||||
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
|
||||
if send_header_only:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
else:
|
||||
async with await anyio.open_file(self.path, mode="rb") as file:
|
||||
for start, end in ranges:
|
||||
await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
|
||||
await file.seek(start)
|
||||
while start < end:
|
||||
chunk = await file.read(min(self.chunk_size, end - start))
|
||||
start += len(chunk)
|
||||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
await send({"type": "http.response.body", "body": b"\n", "more_body": True})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": f"\n--{boundary}--\n".encode("latin-1"),
|
||||
"more_body": False,
|
||||
}
|
||||
)
|
||||
|
||||
def _should_use_range(self, http_if_range: str) -> bool:
|
||||
return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]
|
||||
|
||||
@staticmethod
|
||||
def _parse_range_header(http_range: str, file_size: int) -> list[tuple[int, int]]:
|
||||
ranges: list[tuple[int, int]] = []
|
||||
try:
|
||||
units, range_ = http_range.split("=", 1)
|
||||
except ValueError:
|
||||
raise MalformedRangeHeader()
|
||||
|
||||
units = units.strip().lower()
|
||||
|
||||
if units != "bytes":
|
||||
raise MalformedRangeHeader("Only support bytes range")
|
||||
|
||||
ranges = [
|
||||
(
|
||||
int(_[0]) if _[0] else file_size - int(_[1]),
|
||||
int(_[1]) + 1 if _[0] and _[1] and int(_[1]) < file_size else file_size,
|
||||
)
|
||||
for _ in _RANGE_PATTERN.findall(range_)
|
||||
if _ != ("", "")
|
||||
]
|
||||
|
||||
if len(ranges) == 0:
|
||||
raise MalformedRangeHeader("Range header: range must be requested")
|
||||
|
||||
if any(not (0 <= start < file_size) for start, _ in ranges):
|
||||
raise RangeNotSatisfiable(file_size)
|
||||
|
||||
if any(start > end for start, end in ranges):
|
||||
raise MalformedRangeHeader("Range header: start must be less than end")
|
||||
|
||||
if len(ranges) == 1:
|
||||
return ranges
|
||||
|
||||
# Merge ranges
|
||||
result: list[tuple[int, int]] = []
|
||||
for start, end in ranges:
|
||||
for p in range(len(result)):
|
||||
p_start, p_end = result[p]
|
||||
if start > p_end:
|
||||
continue
|
||||
elif end < p_start:
|
||||
result.insert(p, (start, end)) # THIS IS NOT REACHED!
|
||||
break
|
||||
else:
|
||||
result[p] = (min(start, p_start), max(end, p_end))
|
||||
break
|
||||
else:
|
||||
result.append((start, end))
|
||||
|
||||
return result
|
||||
|
||||
def generate_multipart(
|
||||
self,
|
||||
ranges: typing.Sequence[tuple[int, int]],
|
||||
boundary: str,
|
||||
max_size: int,
|
||||
content_type: str,
|
||||
) -> tuple[int, typing.Callable[[int, int], bytes]]:
|
||||
r"""
|
||||
Multipart response headers generator.
|
||||
|
||||
```
|
||||
--{boundary}\n
|
||||
Content-Type: {content_type}\n
|
||||
Content-Range: bytes {start}-{end-1}/{max_size}\n
|
||||
\n
|
||||
..........content...........\n
|
||||
--{boundary}\n
|
||||
Content-Type: {content_type}\n
|
||||
Content-Range: bytes {start}-{end-1}/{max_size}\n
|
||||
\n
|
||||
..........content...........\n
|
||||
--{boundary}--\n
|
||||
```
|
||||
"""
|
||||
boundary_len = len(boundary)
|
||||
static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size))
|
||||
content_length = sum(
|
||||
(len(str(start)) + len(str(end - 1)) + static_header_part_len) # Headers
|
||||
+ (end - start) # Content
|
||||
for start, end in ranges
|
||||
) + (
|
||||
5 + boundary_len # --boundary--\n
|
||||
)
|
||||
return (
|
||||
content_length,
|
||||
lambda start, end: (
|
||||
f"--{boundary}\n"
|
||||
f"Content-Type: {content_type}\n"
|
||||
f"Content-Range: bytes {start}-{end-1}/{max_size}\n"
|
||||
"\n"
|
||||
).encode("latin-1"),
|
||||
)
|
||||
|
@@ -47,8 +47,7 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
|
||||
including those wrapped in functools.partial objects.
|
||||
"""
|
||||
warnings.warn(
|
||||
"iscoroutinefunction_or_partial is deprecated, "
|
||||
"and will be removed in a future release.",
|
||||
"iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
while isinstance(obj, functools.partial):
|
||||
@@ -57,23 +56,21 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
|
||||
|
||||
|
||||
def request_response(
|
||||
func: typing.Callable[
|
||||
[Request], typing.Union[typing.Awaitable[Response], Response]
|
||||
],
|
||||
func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
|
||||
) -> ASGIApp:
|
||||
"""
|
||||
Takes a function or coroutine `func(request) -> response`,
|
||||
and returns an ASGI application.
|
||||
"""
|
||||
f: typing.Callable[[Request], typing.Awaitable[Response]] = (
|
||||
func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
|
||||
)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
request = Request(scope, receive, send)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if is_async_callable(func):
|
||||
response = await func(request)
|
||||
else:
|
||||
response = await run_in_threadpool(func, request)
|
||||
response = await f(request)
|
||||
await response(scope, receive, send)
|
||||
|
||||
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
|
||||
@@ -101,9 +98,7 @@ def websocket_session(
|
||||
|
||||
|
||||
def get_name(endpoint: typing.Callable[..., typing.Any]) -> str:
|
||||
if inspect.isroutine(endpoint) or inspect.isclass(endpoint):
|
||||
return endpoint.__name__
|
||||
return endpoint.__class__.__name__
|
||||
return getattr(endpoint, "__name__", endpoint.__class__.__name__)
|
||||
|
||||
|
||||
def replace_params(
|
||||
@@ -147,9 +142,7 @@ def compile_path(
|
||||
for match in PARAM_REGEX.finditer(path):
|
||||
param_name, convertor_type = match.groups("str")
|
||||
convertor_type = convertor_type.lstrip(":")
|
||||
assert (
|
||||
convertor_type in CONVERTOR_TYPES
|
||||
), f"Unknown path convertor '{convertor_type}'"
|
||||
assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
|
||||
convertor = CONVERTOR_TYPES[convertor_type]
|
||||
|
||||
path_regex += re.escape(path[idx : match.start()])
|
||||
@@ -203,7 +196,7 @@ class BaseRoute:
|
||||
if scope["type"] == "http":
|
||||
response = PlainTextResponse("Not Found", status_code=404)
|
||||
await response(scope, receive, send)
|
||||
elif scope["type"] == "websocket":
|
||||
elif scope["type"] == "websocket": # pragma: no branch
|
||||
websocket_close = WebSocketClose()
|
||||
await websocket_close(scope, receive, send)
|
||||
return
|
||||
@@ -243,7 +236,7 @@ class Route(BaseRoute):
|
||||
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(app=self.app, *args, **kwargs)
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
|
||||
if methods is None:
|
||||
self.methods = None
|
||||
@@ -255,7 +248,7 @@ class Route(BaseRoute):
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: "typing.Dict[str, typing.Any]"
|
||||
path_params: dict[str, typing.Any]
|
||||
if scope["type"] == "http":
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
@@ -279,9 +272,7 @@ class Route(BaseRoute):
|
||||
if name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
assert not remaining_params
|
||||
return URLPath(path=path, protocol="http")
|
||||
|
||||
@@ -291,9 +282,7 @@ class Route(BaseRoute):
|
||||
if "app" in scope:
|
||||
raise HTTPException(status_code=405, headers=headers)
|
||||
else:
|
||||
response = PlainTextResponse(
|
||||
"Method Not Allowed", status_code=405, headers=headers
|
||||
)
|
||||
response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
await self.app(scope, receive, send)
|
||||
@@ -339,12 +328,12 @@ class WebSocketRoute(BaseRoute):
|
||||
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(app=self.app, *args, **kwargs)
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: "typing.Dict[str, typing.Any]"
|
||||
path_params: dict[str, typing.Any]
|
||||
if scope["type"] == "websocket":
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
@@ -365,9 +354,7 @@ class WebSocketRoute(BaseRoute):
|
||||
if name != self.name or seen_params != expected_params:
|
||||
raise NoMatchFound(name, path_params)
|
||||
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
assert not remaining_params
|
||||
return URLPath(path=path, protocol="websocket")
|
||||
|
||||
@@ -375,11 +362,7 @@ class WebSocketRoute(BaseRoute):
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, WebSocketRoute)
|
||||
and self.path == other.path
|
||||
and self.endpoint == other.endpoint
|
||||
)
|
||||
return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
|
||||
@@ -396,9 +379,7 @@ class Mount(BaseRoute):
|
||||
middleware: typing.Sequence[Middleware] | None = None,
|
||||
) -> None:
|
||||
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
|
||||
assert (
|
||||
app is not None or routes is not None
|
||||
), "Either 'app=...', or 'routes=' must be specified"
|
||||
assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
|
||||
self.path = path.rstrip("/")
|
||||
if app is not None:
|
||||
self._base_app: ASGIApp = app
|
||||
@@ -407,19 +388,17 @@ class Mount(BaseRoute):
|
||||
self.app = self._base_app
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(app=self.app, *args, **kwargs)
|
||||
self.app = cls(self.app, *args, **kwargs)
|
||||
self.name = name
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(
|
||||
self.path + "/{path:path}"
|
||||
)
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
|
||||
|
||||
@property
|
||||
def routes(self) -> list[BaseRoute]:
|
||||
return getattr(self._base_app, "routes", [])
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
path_params: "typing.Dict[str, typing.Any]"
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
path_params: dict[str, typing.Any]
|
||||
if scope["type"] in ("http", "websocket"): # pragma: no branch
|
||||
root_path = scope.get("root_path", "")
|
||||
route_path = get_route_path(scope)
|
||||
match = self.path_regex.match(route_path)
|
||||
@@ -454,9 +433,7 @@ class Mount(BaseRoute):
|
||||
if self.name is not None and name == self.name and "path" in path_params:
|
||||
# 'name' matches "<mount_name>".
|
||||
path_params["path"] = path_params["path"].lstrip("/")
|
||||
path, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
if not remaining_params:
|
||||
return URLPath(path=path)
|
||||
elif self.name is None or name.startswith(self.name + ":"):
|
||||
@@ -468,17 +445,13 @@ class Mount(BaseRoute):
|
||||
remaining_name = name[len(self.name) + 1 :]
|
||||
path_kwarg = path_params.get("path")
|
||||
path_params["path"] = ""
|
||||
path_prefix, remaining_params = replace_params(
|
||||
self.path_format, self.param_convertors, path_params
|
||||
)
|
||||
path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
|
||||
if path_kwarg is not None:
|
||||
remaining_params["path"] = path_kwarg
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(remaining_name, **remaining_params)
|
||||
return URLPath(
|
||||
path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
|
||||
)
|
||||
return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
|
||||
except NoMatchFound:
|
||||
pass
|
||||
raise NoMatchFound(name, path_params)
|
||||
@@ -487,11 +460,7 @@ class Mount(BaseRoute):
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Mount)
|
||||
and self.path == other.path
|
||||
and self.app == other.app
|
||||
)
|
||||
return isinstance(other, Mount) and self.path == other.path and self.app == other.app
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
@@ -512,7 +481,7 @@ class Host(BaseRoute):
|
||||
return getattr(self.app, "routes", [])
|
||||
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
if scope["type"] in ("http", "websocket"): # pragma:no branch
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
match = self.host_regex.match(host)
|
||||
@@ -530,9 +499,7 @@ class Host(BaseRoute):
|
||||
if self.name is not None and name == self.name and "path" in path_params:
|
||||
# 'name' matches "<mount_name>".
|
||||
path = path_params.pop("path")
|
||||
host, remaining_params = replace_params(
|
||||
self.host_format, self.param_convertors, path_params
|
||||
)
|
||||
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
|
||||
if not remaining_params:
|
||||
return URLPath(path=path, host=host)
|
||||
elif self.name is None or name.startswith(self.name + ":"):
|
||||
@@ -542,9 +509,7 @@ class Host(BaseRoute):
|
||||
else:
|
||||
# 'name' matches "<mount_name>:<child_name>".
|
||||
remaining_name = name[len(self.name) + 1 :]
|
||||
host, remaining_params = replace_params(
|
||||
self.host_format, self.param_convertors, path_params
|
||||
)
|
||||
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(remaining_name, **remaining_params)
|
||||
@@ -557,11 +522,7 @@ class Host(BaseRoute):
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Host)
|
||||
and self.host == other.host
|
||||
and self.app == other.app
|
||||
)
|
||||
return isinstance(other, Host) and self.host == other.host and self.app == other.app
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
@@ -589,9 +550,7 @@ class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
|
||||
|
||||
|
||||
def _wrap_gen_lifespan_context(
|
||||
lifespan_context: typing.Callable[
|
||||
[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]
|
||||
],
|
||||
lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]],
|
||||
) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
|
||||
cmgr = contextlib.contextmanager(lifespan_context)
|
||||
|
||||
@@ -734,9 +693,7 @@ class Router:
|
||||
async with self.lifespan_context(app) as maybe_state:
|
||||
if maybe_state is not None:
|
||||
if "state" not in scope:
|
||||
raise RuntimeError(
|
||||
'The server does not support "state" in the lifespan scope.'
|
||||
)
|
||||
raise RuntimeError('The server does not support "state" in the lifespan scope.')
|
||||
scope["state"].update(maybe_state)
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
started = True
|
||||
@@ -810,15 +767,11 @@ class Router:
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return isinstance(other, Router) and self.routes == other.routes
|
||||
|
||||
def mount(
|
||||
self, path: str, app: ASGIApp, name: str | None = None
|
||||
) -> None: # pragma: nocover
|
||||
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
|
||||
route = Mount(path, app=app, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def host(
|
||||
self, host: str, app: ASGIApp, name: str | None = None
|
||||
) -> None: # pragma: no cover
|
||||
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
|
||||
route = Host(host, app=app, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
@@ -829,7 +782,7 @@ class Router:
|
||||
methods: list[str] | None = None,
|
||||
name: str | None = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None: # pragma: nocover
|
||||
) -> None: # pragma: no cover
|
||||
route = Route(
|
||||
path,
|
||||
endpoint=endpoint,
|
||||
@@ -864,11 +817,11 @@ class Router:
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `route` decorator is deprecated, and will be removed in version 1.0.0."
|
||||
"Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", # noqa: E501
|
||||
"Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.add_route(
|
||||
path,
|
||||
func,
|
||||
@@ -889,20 +842,18 @@ class Router:
|
||||
>>> app = Starlette(routes=routes)
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " # noqa: E501
|
||||
"https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
|
||||
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
|
||||
"https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.add_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def add_event_handler(
|
||||
self, event_type: str, func: typing.Callable[[], typing.Any]
|
||||
) -> None: # pragma: no cover
|
||||
def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover
|
||||
assert event_type in ("startup", "shutdown")
|
||||
|
||||
if event_type == "startup":
|
||||
@@ -912,12 +863,12 @@ class Router:
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
|
||||
warnings.warn(
|
||||
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
|
||||
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
|
||||
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
|
||||
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
|
||||
self.add_event_handler(event_type, func)
|
||||
return func
|
||||
|
||||
|
@@ -10,7 +10,7 @@ from starlette.routing import BaseRoute, Host, Mount, Route
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ModuleNotFoundError: # pragma: nocover
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
yaml = None # type: ignore[assignment]
|
||||
|
||||
|
||||
@@ -19,9 +19,7 @@ class OpenAPIResponse(Response):
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
|
||||
assert isinstance(
|
||||
content, dict
|
||||
), "The schema passed to OpenAPIResponse should be a dictionary."
|
||||
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
|
||||
return yaml.dump(content, default_flow_style=False).encode("utf-8")
|
||||
|
||||
|
||||
@@ -31,6 +29,9 @@ class EndpointInfo(typing.NamedTuple):
|
||||
func: typing.Callable[..., typing.Any]
|
||||
|
||||
|
||||
_remove_converter_pattern = re.compile(r":\w+}")
|
||||
|
||||
|
||||
class BaseSchemaGenerator:
|
||||
def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
@@ -73,9 +74,7 @@ class BaseSchemaGenerator:
|
||||
for method in route.methods or ["GET"]:
|
||||
if method == "HEAD":
|
||||
continue
|
||||
endpoints_info.append(
|
||||
EndpointInfo(path, method.lower(), route.endpoint)
|
||||
)
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
|
||||
else:
|
||||
path = self._remove_converter(route.path)
|
||||
for method in ["get", "post", "put", "patch", "delete", "options"]:
|
||||
@@ -93,11 +92,9 @@ class BaseSchemaGenerator:
|
||||
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
|
||||
Should be represented as `/users/{id}` in the OpenAPI schema.
|
||||
"""
|
||||
return re.sub(r":\w+}", "}", path)
|
||||
return _remove_converter_pattern.sub("}", path)
|
||||
|
||||
def parse_docstring(
|
||||
self, func_or_method: typing.Callable[..., typing.Any]
|
||||
) -> dict[str, typing.Any]:
|
||||
def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
|
||||
"""
|
||||
Given a function, parse the docstring as YAML and return a dictionary of info.
|
||||
"""
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import importlib.util
|
||||
import os
|
||||
import stat
|
||||
@@ -31,11 +32,7 @@ class NotModifiedResponse(Response):
|
||||
def __init__(self, headers: Headers):
|
||||
super().__init__(
|
||||
status_code=304,
|
||||
headers={
|
||||
name: value
|
||||
for name, value in headers.items()
|
||||
if name in self.NOT_MODIFIED_HEADERS
|
||||
},
|
||||
headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
|
||||
)
|
||||
|
||||
|
||||
@@ -79,9 +76,7 @@ class StaticFiles:
|
||||
spec = importlib.util.find_spec(package)
|
||||
assert spec is not None, f"Package {package!r} could not be found."
|
||||
assert spec.origin is not None, f"Package {package!r} could not be found."
|
||||
package_directory = os.path.normpath(
|
||||
os.path.join(spec.origin, "..", statics_dir)
|
||||
)
|
||||
package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
|
||||
assert os.path.isdir(
|
||||
package_directory
|
||||
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
|
||||
@@ -109,7 +104,7 @@ class StaticFiles:
|
||||
with OS specific path separators, and any '..', '.' components removed.
|
||||
"""
|
||||
route_path = get_route_path(scope)
|
||||
return os.path.normpath(os.path.join(*route_path.split("/"))) # noqa: E501
|
||||
return os.path.normpath(os.path.join(*route_path.split("/")))
|
||||
|
||||
async def get_response(self, path: str, scope: Scope) -> Response:
|
||||
"""
|
||||
@@ -119,13 +114,15 @@ class StaticFiles:
|
||||
raise HTTPException(status_code=405)
|
||||
|
||||
try:
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(
|
||||
self.lookup_path, path
|
||||
)
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
|
||||
except PermissionError:
|
||||
raise HTTPException(status_code=401)
|
||||
except OSError:
|
||||
raise
|
||||
except OSError as exc:
|
||||
# Filename is too long, so it can't be a valid static file.
|
||||
if exc.errno == errno.ENAMETOOLONG:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
raise exc
|
||||
|
||||
if stat_result and stat.S_ISREG(stat_result.st_mode):
|
||||
# We have a static file to serve.
|
||||
@@ -135,9 +132,7 @@ class StaticFiles:
|
||||
# We're in HTML mode, and have got a directory URL.
|
||||
# Check if we have 'index.html' file to serve.
|
||||
index_path = os.path.join(path, "index.html")
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(
|
||||
self.lookup_path, index_path
|
||||
)
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
|
||||
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
|
||||
if not scope["path"].endswith("/"):
|
||||
# Directory URLs should redirect to always end in "/".
|
||||
@@ -148,9 +143,7 @@ class StaticFiles:
|
||||
|
||||
if self.html:
|
||||
# Check for '404.html' if we're in HTML mode.
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(
|
||||
self.lookup_path, "404.html"
|
||||
)
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
|
||||
if stat_result and stat.S_ISREG(stat_result.st_mode):
|
||||
return FileResponse(full_path, stat_result=stat_result, status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
@@ -162,10 +155,9 @@ class StaticFiles:
|
||||
full_path = os.path.abspath(joined_path)
|
||||
else:
|
||||
full_path = os.path.realpath(joined_path)
|
||||
directory = os.path.realpath(directory)
|
||||
if os.path.commonpath([full_path, directory]) != directory:
|
||||
# Don't allow misbehaving clients to break out of the static files
|
||||
# directory.
|
||||
directory = os.path.realpath(directory)
|
||||
if os.path.commonpath([full_path, directory]) != str(directory):
|
||||
# Don't allow misbehaving clients to break out of the static files directory.
|
||||
continue
|
||||
try:
|
||||
return full_path, os.stat(full_path)
|
||||
@@ -182,9 +174,7 @@ class StaticFiles:
|
||||
) -> Response:
|
||||
request_headers = Headers(scope=scope)
|
||||
|
||||
response = FileResponse(
|
||||
full_path, status_code=status_code, stat_result=stat_result
|
||||
)
|
||||
response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
|
||||
if self.is_not_modified(response.headers, request_headers):
|
||||
return NotModifiedResponse(response.headers)
|
||||
return response
|
||||
@@ -201,17 +191,11 @@ class StaticFiles:
|
||||
try:
|
||||
stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(
|
||||
f"StaticFiles directory '{self.directory}' does not exist."
|
||||
)
|
||||
raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
|
||||
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
|
||||
raise RuntimeError(
|
||||
f"StaticFiles path '{self.directory}' is not a directory."
|
||||
)
|
||||
raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
|
||||
|
||||
def is_not_modified(
|
||||
self, response_headers: Headers, request_headers: Headers
|
||||
) -> bool:
|
||||
def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
|
||||
"""
|
||||
Given the request and response headers, return `True` if an HTTP
|
||||
"Not Modified" response could be returned instead.
|
||||
@@ -227,11 +211,7 @@ class StaticFiles:
|
||||
try:
|
||||
if_modified_since = parsedate(request_headers["if-modified-since"])
|
||||
last_modified = parsedate(response_headers["last-modified"])
|
||||
if (
|
||||
if_modified_since is not None
|
||||
and last_modified is not None
|
||||
and if_modified_since >= last_modified
|
||||
):
|
||||
if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
|
@@ -5,91 +5,9 @@ https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
|
||||
|
||||
And RFC 2324 - https://tools.ietf.org/html/rfc2324
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
__all__ = (
|
||||
"HTTP_100_CONTINUE",
|
||||
"HTTP_101_SWITCHING_PROTOCOLS",
|
||||
"HTTP_102_PROCESSING",
|
||||
"HTTP_103_EARLY_HINTS",
|
||||
"HTTP_200_OK",
|
||||
"HTTP_201_CREATED",
|
||||
"HTTP_202_ACCEPTED",
|
||||
"HTTP_203_NON_AUTHORITATIVE_INFORMATION",
|
||||
"HTTP_204_NO_CONTENT",
|
||||
"HTTP_205_RESET_CONTENT",
|
||||
"HTTP_206_PARTIAL_CONTENT",
|
||||
"HTTP_207_MULTI_STATUS",
|
||||
"HTTP_208_ALREADY_REPORTED",
|
||||
"HTTP_226_IM_USED",
|
||||
"HTTP_300_MULTIPLE_CHOICES",
|
||||
"HTTP_301_MOVED_PERMANENTLY",
|
||||
"HTTP_302_FOUND",
|
||||
"HTTP_303_SEE_OTHER",
|
||||
"HTTP_304_NOT_MODIFIED",
|
||||
"HTTP_305_USE_PROXY",
|
||||
"HTTP_306_RESERVED",
|
||||
"HTTP_307_TEMPORARY_REDIRECT",
|
||||
"HTTP_308_PERMANENT_REDIRECT",
|
||||
"HTTP_400_BAD_REQUEST",
|
||||
"HTTP_401_UNAUTHORIZED",
|
||||
"HTTP_402_PAYMENT_REQUIRED",
|
||||
"HTTP_403_FORBIDDEN",
|
||||
"HTTP_404_NOT_FOUND",
|
||||
"HTTP_405_METHOD_NOT_ALLOWED",
|
||||
"HTTP_406_NOT_ACCEPTABLE",
|
||||
"HTTP_407_PROXY_AUTHENTICATION_REQUIRED",
|
||||
"HTTP_408_REQUEST_TIMEOUT",
|
||||
"HTTP_409_CONFLICT",
|
||||
"HTTP_410_GONE",
|
||||
"HTTP_411_LENGTH_REQUIRED",
|
||||
"HTTP_412_PRECONDITION_FAILED",
|
||||
"HTTP_413_REQUEST_ENTITY_TOO_LARGE",
|
||||
"HTTP_414_REQUEST_URI_TOO_LONG",
|
||||
"HTTP_415_UNSUPPORTED_MEDIA_TYPE",
|
||||
"HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE",
|
||||
"HTTP_417_EXPECTATION_FAILED",
|
||||
"HTTP_418_IM_A_TEAPOT",
|
||||
"HTTP_421_MISDIRECTED_REQUEST",
|
||||
"HTTP_422_UNPROCESSABLE_ENTITY",
|
||||
"HTTP_423_LOCKED",
|
||||
"HTTP_424_FAILED_DEPENDENCY",
|
||||
"HTTP_425_TOO_EARLY",
|
||||
"HTTP_426_UPGRADE_REQUIRED",
|
||||
"HTTP_428_PRECONDITION_REQUIRED",
|
||||
"HTTP_429_TOO_MANY_REQUESTS",
|
||||
"HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE",
|
||||
"HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS",
|
||||
"HTTP_500_INTERNAL_SERVER_ERROR",
|
||||
"HTTP_501_NOT_IMPLEMENTED",
|
||||
"HTTP_502_BAD_GATEWAY",
|
||||
"HTTP_503_SERVICE_UNAVAILABLE",
|
||||
"HTTP_504_GATEWAY_TIMEOUT",
|
||||
"HTTP_505_HTTP_VERSION_NOT_SUPPORTED",
|
||||
"HTTP_506_VARIANT_ALSO_NEGOTIATES",
|
||||
"HTTP_507_INSUFFICIENT_STORAGE",
|
||||
"HTTP_508_LOOP_DETECTED",
|
||||
"HTTP_510_NOT_EXTENDED",
|
||||
"HTTP_511_NETWORK_AUTHENTICATION_REQUIRED",
|
||||
"WS_1000_NORMAL_CLOSURE",
|
||||
"WS_1001_GOING_AWAY",
|
||||
"WS_1002_PROTOCOL_ERROR",
|
||||
"WS_1003_UNSUPPORTED_DATA",
|
||||
"WS_1005_NO_STATUS_RCVD",
|
||||
"WS_1006_ABNORMAL_CLOSURE",
|
||||
"WS_1007_INVALID_FRAME_PAYLOAD_DATA",
|
||||
"WS_1008_POLICY_VIOLATION",
|
||||
"WS_1009_MESSAGE_TOO_BIG",
|
||||
"WS_1010_MANDATORY_EXT",
|
||||
"WS_1011_INTERNAL_ERROR",
|
||||
"WS_1012_SERVICE_RESTART",
|
||||
"WS_1013_TRY_AGAIN_LATER",
|
||||
"WS_1014_BAD_GATEWAY",
|
||||
"WS_1015_TLS_HANDSHAKE",
|
||||
)
|
||||
|
||||
HTTP_100_CONTINUE = 100
|
||||
HTTP_101_SWITCHING_PROTOCOLS = 101
|
||||
HTTP_102_PROCESSING = 102
|
||||
@@ -175,26 +93,3 @@ WS_1012_SERVICE_RESTART = 1012
|
||||
WS_1013_TRY_AGAIN_LATER = 1013
|
||||
WS_1014_BAD_GATEWAY = 1014
|
||||
WS_1015_TLS_HANDSHAKE = 1015
|
||||
|
||||
|
||||
__deprecated__ = {"WS_1004_NO_STATUS_RCVD": 1004, "WS_1005_ABNORMAL_CLOSURE": 1005}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> int:
|
||||
deprecation_changes = {
|
||||
"WS_1004_NO_STATUS_RCVD": "WS_1005_NO_STATUS_RCVD",
|
||||
"WS_1005_ABNORMAL_CLOSURE": "WS_1006_ABNORMAL_CLOSURE",
|
||||
}
|
||||
deprecated = __deprecated__.get(name)
|
||||
if deprecated:
|
||||
warnings.warn(
|
||||
f"'{name}' is deprecated. Use '{deprecation_changes[name]}' instead.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
return deprecated
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(list(__all__) + list(__deprecated__.keys())) # pragma: no cover
|
||||
|
@@ -19,9 +19,9 @@ try:
|
||||
# adding a type ignore for mypy to let us access an attribute that may not exist
|
||||
if hasattr(jinja2, "pass_context"):
|
||||
pass_context = jinja2.pass_context
|
||||
else: # pragma: nocover
|
||||
else: # pragma: no cover
|
||||
pass_context = jinja2.contextfunction # type: ignore[attr-defined]
|
||||
except ModuleNotFoundError: # pragma: nocover
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
jinja2 = None # type: ignore[assignment]
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ class _TemplateResponse(HTMLResponse):
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
request = self.context.get("request", {})
|
||||
extensions = request.get("extensions", {})
|
||||
if "http.response.debug" in extensions:
|
||||
if "http.response.debug" in extensions: # pragma: no branch
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.debug",
|
||||
@@ -66,58 +66,46 @@ class Jinja2Templates:
|
||||
@typing.overload
|
||||
def __init__(
|
||||
self,
|
||||
directory: str
|
||||
| PathLike[typing.AnyStr]
|
||||
| typing.Sequence[str | PathLike[typing.AnyStr]],
|
||||
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
|
||||
*,
|
||||
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
|
||||
| None = None,
|
||||
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
|
||||
**env_options: typing.Any,
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
@typing.overload
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
env: jinja2.Environment,
|
||||
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
|
||||
| None = None,
|
||||
) -> None:
|
||||
...
|
||||
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
directory: str
|
||||
| PathLike[typing.AnyStr]
|
||||
| typing.Sequence[str | PathLike[typing.AnyStr]]
|
||||
| None = None,
|
||||
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None,
|
||||
*,
|
||||
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
|
||||
| None = None,
|
||||
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
|
||||
env: jinja2.Environment | None = None,
|
||||
**env_options: typing.Any,
|
||||
) -> None:
|
||||
if env_options:
|
||||
warnings.warn(
|
||||
"Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.", # noqa: E501
|
||||
"Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
|
||||
assert directory or env, "either 'directory' or 'env' arguments must be passed"
|
||||
assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
|
||||
self.context_processors = context_processors or []
|
||||
if directory is not None:
|
||||
self.env = self._create_env(directory, **env_options)
|
||||
elif env is not None:
|
||||
elif env is not None: # pragma: no branch
|
||||
self.env = env
|
||||
|
||||
self._setup_env_defaults(self.env)
|
||||
|
||||
def _create_env(
|
||||
self,
|
||||
directory: str
|
||||
| PathLike[typing.AnyStr]
|
||||
| typing.Sequence[str | PathLike[typing.AnyStr]],
|
||||
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
|
||||
**env_options: typing.Any,
|
||||
) -> jinja2.Environment:
|
||||
loader = jinja2.FileSystemLoader(directory)
|
||||
@@ -129,7 +117,7 @@ class Jinja2Templates:
|
||||
def _setup_env_defaults(self, env: jinja2.Environment) -> None:
|
||||
@pass_context
|
||||
def url_for(
|
||||
context: typing.Dict[str, typing.Any],
|
||||
context: dict[str, typing.Any],
|
||||
name: str,
|
||||
/,
|
||||
**path_params: typing.Any,
|
||||
@@ -152,8 +140,7 @@ class Jinja2Templates:
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
background: BackgroundTask | None = None,
|
||||
) -> _TemplateResponse:
|
||||
...
|
||||
) -> _TemplateResponse: ...
|
||||
|
||||
@typing.overload
|
||||
def TemplateResponse(
|
||||
@@ -168,25 +155,19 @@ class Jinja2Templates:
|
||||
# Deprecated usage
|
||||
...
|
||||
|
||||
def TemplateResponse(
|
||||
self, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> _TemplateResponse:
|
||||
def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse:
|
||||
if args:
|
||||
if isinstance(
|
||||
args[0], str
|
||||
): # the first argument is template name (old style)
|
||||
if isinstance(args[0], str): # the first argument is template name (old style)
|
||||
warnings.warn(
|
||||
"The `name` is not the first parameter anymore. "
|
||||
"The first parameter should be the `Request` instance.\n"
|
||||
'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', # noqa: E501
|
||||
'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
name = args[0]
|
||||
context = args[1] if len(args) > 1 else kwargs.get("context", {})
|
||||
status_code = (
|
||||
args[2] if len(args) > 2 else kwargs.get("status_code", 200)
|
||||
)
|
||||
status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
|
||||
headers = args[2] if len(args) > 2 else kwargs.get("headers")
|
||||
media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
|
||||
background = args[4] if len(args) > 4 else kwargs.get("background")
|
||||
@@ -198,9 +179,7 @@ class Jinja2Templates:
|
||||
request = args[0]
|
||||
name = args[1] if len(args) > 1 else kwargs["name"]
|
||||
context = args[2] if len(args) > 2 else kwargs.get("context", {})
|
||||
status_code = (
|
||||
args[3] if len(args) > 3 else kwargs.get("status_code", 200)
|
||||
)
|
||||
status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
|
||||
headers = args[4] if len(args) > 4 else kwargs.get("headers")
|
||||
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
|
||||
background = args[6] if len(args) > 6 else kwargs.get("background")
|
||||
@@ -208,7 +187,7 @@ class Jinja2Templates:
|
||||
if "request" not in kwargs:
|
||||
warnings.warn(
|
||||
"The `TemplateResponse` now requires the `request` argument.\n"
|
||||
'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', # noqa: E501
|
||||
'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
|
||||
DeprecationWarning,
|
||||
)
|
||||
if "request" not in kwargs.get("context", {}):
|
||||
|
@@ -5,19 +5,15 @@ import inspect
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import queue
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from concurrent.futures import Future
|
||||
from functools import cached_property
|
||||
from types import GeneratorType
|
||||
from urllib.parse import unquote, urljoin
|
||||
|
||||
import anyio
|
||||
import anyio.abc
|
||||
import anyio.from_thread
|
||||
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
||||
from anyio.streams.stapled import StapledObjectStream
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
@@ -37,16 +33,14 @@ except ModuleNotFoundError: # pragma: no cover
|
||||
"You can install this with:\n"
|
||||
" $ pip install httpx\n"
|
||||
)
|
||||
_PortalFactoryType = typing.Callable[
|
||||
[], typing.ContextManager[anyio.abc.BlockingPortal]
|
||||
]
|
||||
_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]
|
||||
|
||||
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
|
||||
ASGI2App = typing.Callable[[Scope], ASGIInstance]
|
||||
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
||||
|
||||
|
||||
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]
|
||||
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]
|
||||
|
||||
|
||||
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
|
||||
@@ -78,6 +72,16 @@ class _Upgrade(Exception):
|
||||
self.session = session
|
||||
|
||||
|
||||
class WebSocketDenialResponse( # type: ignore[misc]
|
||||
httpx.Response,
|
||||
WebSocketDisconnect,
|
||||
):
|
||||
"""
|
||||
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
|
||||
`WebSocket` is closed before being accepted with a `send_denial_response()`.
|
||||
"""
|
||||
|
||||
|
||||
class WebSocketTestSession:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -89,81 +93,60 @@ class WebSocketTestSession:
|
||||
self.scope = scope
|
||||
self.accepted_subprotocol = None
|
||||
self.portal_factory = portal_factory
|
||||
self._receive_queue: queue.Queue[Message] = queue.Queue()
|
||||
self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
|
||||
self.extra_headers = None
|
||||
|
||||
def __enter__(self) -> WebSocketTestSession:
|
||||
self.exit_stack = contextlib.ExitStack()
|
||||
self.portal = self.exit_stack.enter_context(self.portal_factory())
|
||||
|
||||
try:
|
||||
_: Future[None] = self.portal.start_task_soon(self._run)
|
||||
with contextlib.ExitStack() as stack:
|
||||
self.portal = portal = stack.enter_context(self.portal_factory())
|
||||
fut, cs = portal.start_task(self._run)
|
||||
stack.callback(fut.result)
|
||||
stack.callback(portal.call, cs.cancel)
|
||||
self.send({"type": "websocket.connect"})
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
except Exception:
|
||||
self.exit_stack.close()
|
||||
raise
|
||||
self.accepted_subprotocol = message.get("subprotocol", None)
|
||||
self.extra_headers = message.get("headers", None)
|
||||
return self
|
||||
self.accepted_subprotocol = message.get("subprotocol", None)
|
||||
self.extra_headers = message.get("headers", None)
|
||||
stack.callback(self.close, 1000)
|
||||
self.exit_stack = stack.pop_all()
|
||||
return self
|
||||
|
||||
@cached_property
|
||||
def should_close(self) -> anyio.Event:
|
||||
return anyio.Event()
|
||||
def __exit__(self, *args: typing.Any) -> bool | None:
|
||||
return self.exit_stack.__exit__(*args)
|
||||
|
||||
async def _notify_close(self) -> None:
|
||||
self.should_close.set()
|
||||
|
||||
def __exit__(self, *args: typing.Any) -> None:
|
||||
try:
|
||||
self.close(1000)
|
||||
finally:
|
||||
self.portal.start_task_soon(self._notify_close)
|
||||
self.exit_stack.close()
|
||||
while not self._send_queue.empty():
|
||||
message = self._send_queue.get()
|
||||
if isinstance(message, BaseException):
|
||||
raise message
|
||||
|
||||
async def _run(self) -> None:
|
||||
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
|
||||
"""
|
||||
The sub-thread in which the websocket session runs.
|
||||
"""
|
||||
send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
|
||||
send_tx, send_rx = send
|
||||
receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
|
||||
receive_tx, receive_rx = receive
|
||||
with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
|
||||
self._receive_tx = receive_tx
|
||||
self._send_rx = send_rx
|
||||
task_status.started(cs)
|
||||
await self.app(self.scope, receive_rx.receive, send_tx.send)
|
||||
|
||||
async def run_app(tg: anyio.abc.TaskGroup) -> None:
|
||||
try:
|
||||
await self.app(self.scope, self._asgi_receive, self._asgi_send)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
...
|
||||
except BaseException as exc:
|
||||
self._send_queue.put(exc)
|
||||
raise
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(run_app, tg)
|
||||
await self.should_close.wait()
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
async def _asgi_receive(self) -> Message:
|
||||
while self._receive_queue.empty():
|
||||
await anyio.sleep(0)
|
||||
return self._receive_queue.get()
|
||||
|
||||
async def _asgi_send(self, message: Message) -> None:
|
||||
self._send_queue.put(message)
|
||||
# wait for cs.cancel to be called before closing streams
|
||||
await anyio.sleep_forever()
|
||||
|
||||
def _raise_on_close(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.close":
|
||||
raise WebSocketDisconnect(
|
||||
message.get("code", 1000), message.get("reason", "")
|
||||
)
|
||||
raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
|
||||
elif message["type"] == "websocket.http.response.start":
|
||||
status_code: int = message["status"]
|
||||
headers: list[tuple[bytes, bytes]] = message["headers"]
|
||||
body: list[bytes] = []
|
||||
while True:
|
||||
message = self.receive()
|
||||
assert message["type"] == "websocket.http.response.body"
|
||||
body.append(message["body"])
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
|
||||
|
||||
def send(self, message: Message) -> None:
|
||||
self._receive_queue.put(message)
|
||||
self.portal.call(self._receive_tx.send, message)
|
||||
|
||||
def send_text(self, data: str) -> None:
|
||||
self.send({"type": "websocket.receive", "text": data})
|
||||
@@ -171,9 +154,7 @@ class WebSocketTestSession:
|
||||
def send_bytes(self, data: bytes) -> None:
|
||||
self.send({"type": "websocket.receive", "bytes": data})
|
||||
|
||||
def send_json(
|
||||
self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text"
|
||||
) -> None:
|
||||
def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
|
||||
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
if mode == "text":
|
||||
self.send({"type": "websocket.receive", "text": text})
|
||||
@@ -184,10 +165,7 @@ class WebSocketTestSession:
|
||||
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
|
||||
def receive(self) -> Message:
|
||||
message = self._send_queue.get()
|
||||
if isinstance(message, BaseException):
|
||||
raise message
|
||||
return message
|
||||
return self.portal.call(self._send_rx.receive)
|
||||
|
||||
def receive_text(self) -> str:
|
||||
message = self.receive()
|
||||
@@ -199,9 +177,7 @@ class WebSocketTestSession:
|
||||
self._raise_on_close(message)
|
||||
return typing.cast(bytes, message["bytes"])
|
||||
|
||||
def receive_json(
|
||||
self, mode: typing.Literal["text", "binary"] = "text"
|
||||
) -> typing.Any:
|
||||
def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
if mode == "text":
|
||||
@@ -219,6 +195,7 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
raise_server_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
*,
|
||||
client: tuple[str, int],
|
||||
app_state: dict[str, typing.Any],
|
||||
) -> None:
|
||||
self.app = app
|
||||
@@ -226,6 +203,7 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
self.root_path = root_path
|
||||
self.portal_factory = portal_factory
|
||||
self.app_state = app_state
|
||||
self.client = client
|
||||
|
||||
def handle_request(self, request: httpx.Request) -> httpx.Response:
|
||||
scheme = request.url.scheme
|
||||
@@ -252,10 +230,7 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
headers = [(b"host", (f"{host}:{port}").encode())]
|
||||
|
||||
# Include other request headers.
|
||||
headers += [
|
||||
(key.lower().encode(), value.encode())
|
||||
for key, value in request.headers.multi_items()
|
||||
]
|
||||
headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
|
||||
|
||||
scope: dict[str, typing.Any]
|
||||
|
||||
@@ -268,15 +243,16 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
scope = {
|
||||
"type": "websocket",
|
||||
"path": unquote(path),
|
||||
"raw_path": raw_path,
|
||||
"raw_path": raw_path.split(b"?", 1)[0],
|
||||
"root_path": self.root_path,
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": None,
|
||||
"client": self.client,
|
||||
"server": [host, port],
|
||||
"subprotocols": subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
session = WebSocketTestSession(self.app, scope, self.portal_factory)
|
||||
raise _Upgrade(session)
|
||||
@@ -286,12 +262,12 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"path": unquote(path),
|
||||
"raw_path": raw_path,
|
||||
"raw_path": raw_path.split(b"?", 1)[0],
|
||||
"root_path": self.root_path,
|
||||
"scheme": scheme,
|
||||
"query_string": query.encode(),
|
||||
"headers": headers,
|
||||
"client": None,
|
||||
"client": self.client,
|
||||
"server": [host, port],
|
||||
"extensions": {"http.response.debug": {}},
|
||||
"state": self.app_state.copy(),
|
||||
@@ -336,22 +312,13 @@ class _TestClientTransport(httpx.BaseTransport):
|
||||
nonlocal raw_kwargs, response_started, template, context
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
assert (
|
||||
not response_started
|
||||
), 'Received multiple "http.response.start" messages.'
|
||||
assert not response_started, 'Received multiple "http.response.start" messages.'
|
||||
raw_kwargs["status_code"] = message["status"]
|
||||
raw_kwargs["headers"] = [
|
||||
(key.decode(), value.decode())
|
||||
for key, value in message.get("headers", [])
|
||||
]
|
||||
raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
|
||||
response_started = True
|
||||
elif message["type"] == "http.response.body":
|
||||
assert (
|
||||
response_started
|
||||
), 'Received "http.response.body" without "http.response.start".'
|
||||
assert (
|
||||
not response_complete.is_set()
|
||||
), 'Received "http.response.body" after response completed.'
|
||||
assert response_started, 'Received "http.response.body" without "http.response.start".'
|
||||
assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if request.method != "HEAD":
|
||||
@@ -405,10 +372,9 @@ class TestClient(httpx.Client):
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
follow_redirects: bool = True,
|
||||
client: tuple[str, int] = ("testclient", 50000),
|
||||
) -> None:
|
||||
self.async_backend = _AsyncBackend(
|
||||
backend=backend, backend_options=backend_options or {}
|
||||
)
|
||||
self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
|
||||
if _is_asgi3(app):
|
||||
asgi_app = app
|
||||
else:
|
||||
@@ -422,12 +388,12 @@ class TestClient(httpx.Client):
|
||||
raise_server_exceptions=raise_server_exceptions,
|
||||
root_path=root_path,
|
||||
app_state=self.app_state,
|
||||
client=client,
|
||||
)
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.setdefault("user-agent", "testclient")
|
||||
super().__init__(
|
||||
app=self.app,
|
||||
base_url=base_url,
|
||||
headers=headers,
|
||||
transport=transport,
|
||||
@@ -440,32 +406,9 @@ class TestClient(httpx.Client):
|
||||
if self.portal is not None:
|
||||
yield self.portal
|
||||
else:
|
||||
with anyio.from_thread.start_blocking_portal(
|
||||
**self.async_backend
|
||||
) as portal:
|
||||
with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
|
||||
yield portal
|
||||
|
||||
def _choose_redirect_arg(
|
||||
self, follow_redirects: bool | None, allow_redirects: bool | None
|
||||
) -> bool | httpx._client.UseClientDefault:
|
||||
redirect: bool | httpx._client.UseClientDefault = (
|
||||
httpx._client.USE_CLIENT_DEFAULT
|
||||
)
|
||||
if allow_redirects is not None:
|
||||
message = (
|
||||
"The `allow_redirects` argument is deprecated. "
|
||||
"Use `follow_redirects` instead."
|
||||
)
|
||||
warnings.warn(message, DeprecationWarning)
|
||||
redirect = allow_redirects
|
||||
if follow_redirects is not None:
|
||||
redirect = follow_redirects
|
||||
elif allow_redirects is not None and follow_redirects is not None:
|
||||
raise RuntimeError( # pragma: no cover
|
||||
"Cannot use both `allow_redirects` and `follow_redirects`."
|
||||
)
|
||||
return redirect
|
||||
|
||||
def request( # type: ignore[override]
|
||||
self,
|
||||
method: str,
|
||||
@@ -478,16 +421,12 @@ class TestClient(httpx.Client):
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | None = None,
|
||||
allow_redirects: bool | None = None,
|
||||
timeout: httpx._types.TimeoutTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
url = self._merge_url(url)
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().request(
|
||||
method,
|
||||
url,
|
||||
@@ -499,7 +438,7 @@ class TestClient(httpx.Client):
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -511,22 +450,18 @@ class TestClient(httpx.Client):
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | None = None,
|
||||
allow_redirects: bool | None = None,
|
||||
timeout: httpx._types.TimeoutTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().get(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -538,22 +473,18 @@ class TestClient(httpx.Client):
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | None = None,
|
||||
allow_redirects: bool | None = None,
|
||||
timeout: httpx._types.TimeoutTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().options(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -565,22 +496,18 @@ class TestClient(httpx.Client):
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | None = None,
|
||||
allow_redirects: bool | None = None,
|
||||
timeout: httpx._types.TimeoutTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().head(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -596,15 +523,11 @@ class TestClient(httpx.Client):
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | None = None,
|
||||
allow_redirects: bool | None = None,
|
||||
timeout: httpx._types.TimeoutTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().post(
|
||||
url,
|
||||
content=content,
|
||||
@@ -615,7 +538,7 @@ class TestClient(httpx.Client):
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -631,15 +554,11 @@ class TestClient(httpx.Client):
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | None = None,
|
||||
allow_redirects: bool | None = None,
|
||||
timeout: httpx._types.TimeoutTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().put(
|
||||
url,
|
||||
content=content,
|
||||
@@ -650,7 +569,7 @@ class TestClient(httpx.Client):
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -666,15 +585,11 @@ class TestClient(httpx.Client):
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | None = None,
|
||||
allow_redirects: bool | None = None,
|
||||
timeout: httpx._types.TimeoutTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().patch(
|
||||
url,
|
||||
content=content,
|
||||
@@ -685,7 +600,7 @@ class TestClient(httpx.Client):
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -697,22 +612,18 @@ class TestClient(httpx.Client):
|
||||
params: httpx._types.QueryParamTypes | None = None,
|
||||
headers: httpx._types.HeaderTypes | None = None,
|
||||
cookies: httpx._types.CookieTypes | None = None,
|
||||
auth: httpx._types.AuthTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | None = None,
|
||||
allow_redirects: bool | None = None,
|
||||
timeout: httpx._types.TimeoutTypes
|
||||
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
|
||||
extensions: dict[str, typing.Any] | None = None,
|
||||
) -> httpx.Response:
|
||||
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
|
||||
return super().delete(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
follow_redirects=redirect,
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
@@ -742,22 +653,22 @@ class TestClient(httpx.Client):
|
||||
|
||||
def __enter__(self) -> TestClient:
|
||||
with contextlib.ExitStack() as stack:
|
||||
self.portal = portal = stack.enter_context(
|
||||
anyio.from_thread.start_blocking_portal(**self.async_backend)
|
||||
)
|
||||
self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
|
||||
|
||||
@stack.callback
|
||||
def reset_portal() -> None:
|
||||
self.portal = None
|
||||
|
||||
send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
|
||||
receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
|
||||
send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
|
||||
receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
|
||||
send1, receive1 = anyio.create_memory_object_stream(math.inf)
|
||||
send2, receive2 = anyio.create_memory_object_stream(math.inf)
|
||||
self.stream_send = StapledObjectStream(send1, receive1)
|
||||
self.stream_receive = StapledObjectStream(send2, receive2)
|
||||
send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = (
|
||||
anyio.create_memory_object_stream(math.inf)
|
||||
)
|
||||
receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = (
|
||||
anyio.create_memory_object_stream(math.inf)
|
||||
)
|
||||
for channel in (*send, *receive):
|
||||
stack.callback(channel.close)
|
||||
self.stream_send = StapledObjectStream(*send)
|
||||
self.stream_receive = StapledObjectStream(*receive)
|
||||
self.task = portal.start_task_soon(self.lifespan)
|
||||
portal.call(self.wait_startup)
|
||||
|
||||
@@ -803,12 +714,11 @@ class TestClient(httpx.Client):
|
||||
self.task.result()
|
||||
return message
|
||||
|
||||
async with self.stream_send:
|
||||
await self.stream_receive.send({"type": "lifespan.shutdown"})
|
||||
message = await receive()
|
||||
assert message["type"] in (
|
||||
"lifespan.shutdown.complete",
|
||||
"lifespan.shutdown.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.shutdown.failed":
|
||||
await receive()
|
||||
await self.stream_receive.send({"type": "lifespan.shutdown"})
|
||||
message = await receive()
|
||||
assert message["type"] in (
|
||||
"lifespan.shutdown.complete",
|
||||
"lifespan.shutdown.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.shutdown.failed":
|
||||
await receive()
|
||||
|
@@ -16,15 +16,9 @@ Send = typing.Callable[[Message], typing.Awaitable[None]]
|
||||
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
|
||||
|
||||
StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
|
||||
StatefulLifespan = typing.Callable[
|
||||
[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
|
||||
]
|
||||
StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]]
|
||||
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
|
||||
|
||||
HTTPExceptionHandler = typing.Callable[
|
||||
["Request", Exception], "Response | typing.Awaitable[Response]"
|
||||
]
|
||||
WebSocketExceptionHandler = typing.Callable[
|
||||
["WebSocket", Exception], typing.Awaitable[None]
|
||||
]
|
||||
HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"]
|
||||
WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]]
|
||||
ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]
|
||||
|
@@ -5,6 +5,7 @@ import json
|
||||
import typing
|
||||
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Message, Receive, Scope, Send
|
||||
|
||||
|
||||
@@ -12,10 +13,11 @@ class WebSocketState(enum.Enum):
|
||||
CONNECTING = 0
|
||||
CONNECTED = 1
|
||||
DISCONNECTED = 2
|
||||
RESPONSE = 3
|
||||
|
||||
|
||||
class WebSocketDisconnect(Exception):
|
||||
def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
|
||||
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
self.code = code
|
||||
self.reason = reason or ""
|
||||
|
||||
@@ -37,10 +39,7 @@ class WebSocket(HTTPConnection):
|
||||
message = await self._receive()
|
||||
message_type = message["type"]
|
||||
if message_type != "websocket.connect":
|
||||
raise RuntimeError(
|
||||
'Expected ASGI message "websocket.connect", '
|
||||
f"but got {message_type!r}"
|
||||
)
|
||||
raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
|
||||
self.client_state = WebSocketState.CONNECTED
|
||||
return message
|
||||
elif self.client_state == WebSocketState.CONNECTED:
|
||||
@@ -48,16 +47,13 @@ class WebSocket(HTTPConnection):
|
||||
message_type = message["type"]
|
||||
if message_type not in {"websocket.receive", "websocket.disconnect"}:
|
||||
raise RuntimeError(
|
||||
'Expected ASGI message "websocket.receive" or '
|
||||
f'"websocket.disconnect", but got {message_type!r}'
|
||||
f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
|
||||
)
|
||||
if message_type == "websocket.disconnect":
|
||||
self.client_state = WebSocketState.DISCONNECTED
|
||||
return message
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Cannot call "receive" once a disconnect message has been received.'
|
||||
)
|
||||
raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
|
||||
|
||||
async def send(self, message: Message) -> None:
|
||||
"""
|
||||
@@ -65,13 +61,15 @@ class WebSocket(HTTPConnection):
|
||||
"""
|
||||
if self.application_state == WebSocketState.CONNECTING:
|
||||
message_type = message["type"]
|
||||
if message_type not in {"websocket.accept", "websocket.close"}:
|
||||
if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
|
||||
raise RuntimeError(
|
||||
'Expected ASGI message "websocket.accept" or '
|
||||
f'"websocket.close", but got {message_type!r}'
|
||||
'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
|
||||
f"but got {message_type!r}"
|
||||
)
|
||||
if message_type == "websocket.close":
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
elif message_type == "websocket.http.response.start":
|
||||
self.application_state = WebSocketState.RESPONSE
|
||||
else:
|
||||
self.application_state = WebSocketState.CONNECTED
|
||||
await self._send(message)
|
||||
@@ -79,16 +77,22 @@ class WebSocket(HTTPConnection):
|
||||
message_type = message["type"]
|
||||
if message_type not in {"websocket.send", "websocket.close"}:
|
||||
raise RuntimeError(
|
||||
'Expected ASGI message "websocket.send" or "websocket.close", '
|
||||
f"but got {message_type!r}"
|
||||
f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
|
||||
)
|
||||
if message_type == "websocket.close":
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
try:
|
||||
await self._send(message)
|
||||
except IOError:
|
||||
except OSError:
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
raise WebSocketDisconnect(code=1006)
|
||||
elif self.application_state == WebSocketState.RESPONSE:
|
||||
message_type = message["type"]
|
||||
if message_type != "websocket.http.response.body":
|
||||
raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
|
||||
if not message.get("more_body", False):
|
||||
self.application_state = WebSocketState.DISCONNECTED
|
||||
await self._send(message)
|
||||
else:
|
||||
raise RuntimeError('Cannot call "send" once a close message has been sent.')
|
||||
|
||||
@@ -99,12 +103,10 @@ class WebSocket(HTTPConnection):
|
||||
) -> None:
|
||||
headers = headers or []
|
||||
|
||||
if self.client_state == WebSocketState.CONNECTING:
|
||||
if self.client_state == WebSocketState.CONNECTING: # pragma: no branch
|
||||
# If we haven't yet seen the 'connect' message, then wait for it first.
|
||||
await self.receive()
|
||||
await self.send(
|
||||
{"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
|
||||
)
|
||||
await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
|
||||
|
||||
def _raise_on_disconnect(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.disconnect":
|
||||
@@ -112,18 +114,14 @@ class WebSocket(HTTPConnection):
|
||||
|
||||
async def receive_text(self) -> str:
|
||||
if self.application_state != WebSocketState.CONNECTED:
|
||||
raise RuntimeError(
|
||||
'WebSocket is not connected. Need to call "accept" first.'
|
||||
)
|
||||
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
return typing.cast(str, message["text"])
|
||||
|
||||
async def receive_bytes(self) -> bytes:
|
||||
if self.application_state != WebSocketState.CONNECTED:
|
||||
raise RuntimeError(
|
||||
'WebSocket is not connected. Need to call "accept" first.'
|
||||
)
|
||||
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
return typing.cast(bytes, message["bytes"])
|
||||
@@ -132,9 +130,7 @@ class WebSocket(HTTPConnection):
|
||||
if mode not in {"text", "binary"}:
|
||||
raise RuntimeError('The "mode" argument should be "text" or "binary".')
|
||||
if self.application_state != WebSocketState.CONNECTED:
|
||||
raise RuntimeError(
|
||||
'WebSocket is not connected. Need to call "accept" first.'
|
||||
)
|
||||
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
|
||||
message = await self.receive()
|
||||
self._raise_on_disconnect(message)
|
||||
|
||||
@@ -181,9 +177,13 @@ class WebSocket(HTTPConnection):
|
||||
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
|
||||
|
||||
async def close(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
await self.send(
|
||||
{"type": "websocket.close", "code": code, "reason": reason or ""}
|
||||
)
|
||||
await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
|
||||
|
||||
async def send_denial_response(self, response: Response) -> None:
|
||||
if "websocket.http.response" in self.scope.get("extensions", {}):
|
||||
await response(self.scope, self.receive, self.send)
|
||||
else:
|
||||
raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
|
||||
|
||||
|
||||
class WebSocketClose:
|
||||
@@ -192,6 +192,4 @@ class WebSocketClose:
|
||||
self.reason = reason or ""
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await send(
|
||||
{"type": "websocket.close", "code": self.code, "reason": self.reason}
|
||||
)
|
||||
await send({"type": "websocket.close", "code": self.code, "reason": self.reason})
|
||||
|
Reference in New Issue
Block a user