1
0
mirror of https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git synced 2025-08-14 00:25:46 +02:00

Проверка 09.02.2025

This commit is contained in:
MoonTestUse1
2025-02-09 01:11:49 +06:00
parent ce52f8a23a
commit 0aa3ef8fc2
5827 changed files with 14316 additions and 1906434 deletions

View File

@@ -1 +1 @@
__version__ = "0.36.3"
__version__ = "0.45.3"

View File

@@ -1,28 +0,0 @@
import hashlib
# Compat wrapper to always include the `usedforsecurity=...` parameter,
# which is only added from Python 3.9 onwards.
# We use this flag to indicate that we use `md5` hashes only for non-security
# cases (our ETag checksums).
# If we don't indicate that we're using MD5 for non-security related reasons,
# then attempting to use this function will raise an error when used
# environments which enable a strict "FIPs mode".
#
# See issue: https://github.com/encode/starlette/issues/1365
try:
# check if the Python version supports the parameter
# using usedforsecurity=False to avoid an exception on FIPS systems
# that reject usedforsecurity=True
hashlib.md5(b"data", usedforsecurity=False) # type: ignore[call-arg]
def md5_hexdigest(
data: bytes, *, usedforsecurity: bool = True
) -> str: # pragma: no cover
return hashlib.md5( # type: ignore[call-arg]
data, usedforsecurity=usedforsecurity
).hexdigest()
except TypeError: # pragma: no cover
def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str:
return hashlib.md5(data).hexdigest()

View File

@@ -6,25 +6,14 @@ from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.types import (
ASGIApp,
ExceptionHandler,
HTTPExceptionHandler,
Message,
Receive,
Scope,
Send,
WebSocketExceptionHandler,
)
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.websockets import WebSocket
ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
StatusHandlers = typing.Dict[int, ExceptionHandler]
ExceptionHandlers = dict[typing.Any, ExceptionHandler]
StatusHandlers = dict[int, ExceptionHandler]
def _lookup_exception_handler(
exc_handlers: ExceptionHandlers, exc: Exception
) -> ExceptionHandler | None:
def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
@@ -64,24 +53,13 @@ def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASG
raise exc
if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc
raise RuntimeError("Caught handled exception, but response already started.") from exc
if scope["type"] == "http":
nonlocal conn
handler = typing.cast(HTTPExceptionHandler, handler)
conn = typing.cast(Request, conn)
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc)
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc) # type: ignore
if response is not None:
await response(scope, receive, sender)
elif scope["type"] == "websocket":
handler = typing.cast(WebSocketExceptionHandler, handler)
conn = typing.cast(WebSocket, conn)
if is_async_callable(handler):
await handler(conn, exc)
else:
await run_in_threadpool(handler, conn, exc)
return wrapped_app

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio
import functools
import re
import sys
import typing
from contextlib import contextmanager
@@ -17,7 +16,7 @@ else: # pragma: no cover
has_exceptiongroups = True
if sys.version_info < (3, 11): # pragma: no cover
try:
from exceptiongroup import BaseExceptionGroup
from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found]
except ImportError:
has_exceptiongroups = False
@@ -26,41 +25,31 @@ AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
@typing.overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]:
...
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
@typing.overload
def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]:
...
def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ...
def is_async_callable(obj: typing.Any) -> typing.Any:
while isinstance(obj, functools.partial):
obj = obj.func
return asyncio.iscoroutinefunction(obj) or (
callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
)
return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))
T_co = typing.TypeVar("T_co", covariant=True)
class AwaitableOrContextManager(
typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]
):
...
class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ...
class SupportsAsyncClose(typing.Protocol):
async def close(self) -> None:
... # pragma: no cover
async def close(self) -> None: ... # pragma: no cover
SupportsAsyncCloseType = typing.TypeVar(
"SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
)
SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
@@ -86,14 +75,26 @@ def collapse_excgroups() -> typing.Generator[None, None, None]:
try:
yield
except BaseException as exc:
if has_exceptiongroups:
if has_exceptiongroups: # pragma: no cover
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0] # pragma: no cover
exc = exc.exceptions[0]
raise exc
def get_route_path(scope: Scope) -> str:
path: str = scope["path"]
root_path = scope.get("root_path", "")
route_path = re.sub(r"^" + root_path, "", scope["path"])
return route_path
if not root_path:
return path
if not path.startswith(root_path):
return path
if path == root_path:
return ""
if path[len(root_path)] == "/":
return path[len(root_path) :]
return path

View File

@@ -10,7 +10,7 @@ else: # pragma: no cover
from typing_extensions import ParamSpec
from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
@@ -25,34 +25,7 @@ P = ParamSpec("P")
class Starlette:
"""
Creates an application instance.
**Parameters:**
* **debug** - Boolean indicating if debug tracebacks should be returned on errors.
* **routes** - A list of routes to serve incoming HTTP and WebSocket requests.
* **middleware** - A list of middleware to run for every request. A starlette
application will always automatically include two middleware classes.
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
any uncaught errors occurring anywhere in the entire stack.
`ExceptionMiddleware` is added as the very innermost middleware, to deal
with handled exception cases occurring in the routing or endpoints.
* **exception_handlers** - A mapping of either integer status codes,
or exception class types onto callables which handle the exceptions.
Exception handler callables should be of the form
`handler(request, exc) -> response` and may be either standard functions, or
async functions.
* **on_startup** - A list of callables to run on application startup.
Startup handler callables do not take any arguments, and may be either
standard functions, or async functions.
* **on_shutdown** - A list of callables to run on application shutdown.
Shutdown handler callables do not take any arguments, and may be either
standard functions, or async functions.
* **lifespan** - A lifespan context function, which can be used to perform
startup and shutdown tasks. This is a newer style that replaces the
`on_startup` and `on_shutdown` handlers. Use one or the other, not both.
"""
"""Creates an Starlette application."""
def __init__(
self: AppType,
@@ -64,6 +37,32 @@ class Starlette:
on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
lifespan: Lifespan[AppType] | None = None,
) -> None:
"""Initializes the application.
Parameters:
debug: Boolean indicating if debug tracebacks should be returned on errors.
routes: A list of routes to serve incoming HTTP and WebSocket requests.
middleware: A list of middleware to run for every request. A starlette
application will always automatically include two middleware classes.
`ServerErrorMiddleware` is added as the very outermost middleware, to handle
any uncaught errors occurring anywhere in the entire stack.
`ExceptionMiddleware` is added as the very innermost middleware, to deal
with handled exception cases occurring in the routing or endpoints.
exception_handlers: A mapping of either integer status codes,
or exception class types onto callables which handle the exceptions.
Exception handler callables should be of the form
`handler(request, exc) -> response` and may be either standard functions, or
async functions.
on_startup: A list of callables to run on application startup.
Startup handler callables do not take any arguments, and may be either
standard functions, or async functions.
on_shutdown: A list of callables to run on application shutdown.
Shutdown handler callables do not take any arguments, and may be either
standard functions, or async functions.
lifespan: A lifespan context function, which can be used to perform
startup and shutdown tasks. This is a newer style that replaces the
`on_startup` and `on_shutdown` handlers. Use one or the other, not both.
"""
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
assert lifespan is None or (
@@ -72,21 +71,15 @@ class Starlette:
self.debug = debug
self.state = State()
self.router = Router(
routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
)
self.exception_handlers = (
{} if exception_handlers is None else dict(exception_handlers)
)
self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack: typing.Optional[ASGIApp] = None
self.middleware_stack: ASGIApp | None = None
def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
error_handler = None
exception_handlers: dict[
typing.Any, typing.Callable[[Request, Exception], Response]
] = {}
exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {}
for key, value in self.exception_handlers.items():
if key in (500, Exception):
@@ -97,16 +90,12 @@ class Starlette:
middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ self.user_middleware
+ [
Middleware(
ExceptionMiddleware, handlers=exception_handlers, debug=debug
)
]
+ [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
)
app = self.router
for cls, args, kwargs in reversed(middleware):
app = cls(app=app, *args, **kwargs)
app = cls(app, *args, **kwargs)
return app
@property
@@ -123,7 +112,7 @@ class Starlette:
await self.middleware_stack(scope, receive, send)
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
return self.router.on_event(event_type) # pragma: nocover
return self.router.on_event(event_type) # pragma: no cover
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
self.router.mount(path, app=app, name=name) # pragma: no cover
@@ -133,7 +122,7 @@ class Starlette:
def add_middleware(
self,
middleware_class: typing.Type[_MiddlewareClass[P]],
middleware_class: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
@@ -143,7 +132,7 @@ class Starlette:
def add_exception_handler(
self,
exc_class_or_status_code: int | typing.Type[Exception],
exc_class_or_status_code: int | type[Exception],
handler: ExceptionHandler,
) -> None: # pragma: no cover
self.exception_handlers[exc_class_or_status_code] = handler
@@ -159,13 +148,11 @@ class Starlette:
self,
path: str,
route: typing.Callable[[Request], typing.Awaitable[Response] | Response],
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> None: # pragma: no cover
self.router.add_route(
path, route, methods=methods, name=name, include_in_schema=include_in_schema
)
self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema)
def add_websocket_route(
self,
@@ -175,16 +162,14 @@ class Starlette:
) -> None: # pragma: no cover
self.router.add_websocket_route(path, route, name=name)
def exception_handler(
self, exc_class_or_status_code: int | typing.Type[Exception]
) -> typing.Callable: # type: ignore[type-arg]
def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.", # noqa: E501
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_exception_handler(exc_class_or_status_code, func)
return func
@@ -205,12 +190,12 @@ class Starlette:
>>> app = Starlette(routes=routes)
"""
warnings.warn(
"The `route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/routing/ for the recommended approach.", # noqa: E501
"The `route` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/routing/ for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.router.add_route(
path,
func,
@@ -231,18 +216,18 @@ class Starlette:
>>> app = Starlette(routes=routes)
"""
warnings.warn(
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.router.add_websocket_route(path, func, name=name)
return func
return decorator
def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
@@ -251,15 +236,13 @@ class Starlette:
>>> app = Starlette(middleware=middleware)
"""
warnings.warn(
"The `middleware` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.", # noqa: E501
"The `middleware` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.",
DeprecationWarning,
)
assert (
middleware_type == "http"
), 'Currently only middleware("http") is supported.'
assert middleware_type == "http", 'Currently only middleware("http") is supported.'
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
return func

View File

@@ -31,9 +31,7 @@ def requires(
scopes: str | typing.Sequence[str],
status_code: int = 403,
redirect: str | None = None,
) -> typing.Callable[
[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]
]:
) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
def decorator(
@@ -45,17 +43,13 @@ def requires(
type_ = parameter.name
break
else:
raise Exception(
f'No "request" or "websocket" argument on function "{func}"'
)
raise Exception(f'No "request" or "websocket" argument on function "{func}"')
if type_ == "websocket":
# Handle websocket functions. (Always async)
@functools.wraps(func)
async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
websocket = kwargs.get(
"websocket", args[idx] if idx < len(args) else None
)
websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
assert isinstance(websocket, WebSocket)
if not has_required_scope(websocket, scopes_list):
@@ -75,10 +69,7 @@ def requires(
if not has_required_scope(request, scopes_list):
if redirect is not None:
orig_request_qparam = urlencode({"next": str(request.url)})
next_url = "{redirect_path}?{orig_request}".format(
redirect_path=request.url_for(redirect),
orig_request=orig_request_qparam,
)
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
return RedirectResponse(url=next_url, status_code=303)
raise HTTPException(status_code=status_code)
return await func(*args, **kwargs)
@@ -95,10 +86,7 @@ def requires(
if not has_required_scope(request, scopes_list):
if redirect is not None:
orig_request_qparam = urlencode({"next": str(request.url)})
next_url = "{redirect_path}?{orig_request}".format(
redirect_path=request.url_for(redirect),
orig_request=orig_request_qparam,
)
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
return RedirectResponse(url=next_url, status_code=303)
raise HTTPException(status_code=status_code)
return func(*args, **kwargs)
@@ -113,9 +101,7 @@ class AuthenticationError(Exception):
class AuthenticationBackend:
async def authenticate(
self, conn: HTTPConnection
) -> tuple[AuthCredentials, BaseUser] | None:
async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
raise NotImplementedError() # pragma: no cover

View File

@@ -15,9 +15,7 @@ P = ParamSpec("P")
class BackgroundTask:
def __init__(
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
self.func = func
self.args = args
self.kwargs = kwargs
@@ -34,9 +32,7 @@ class BackgroundTasks(BackgroundTask):
def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
self.tasks = list(tasks) if tasks else []
def add_task(
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
task = BackgroundTask(func, *args, **kwargs)
self.tasks.append(task)

View File

@@ -16,16 +16,15 @@ P = ParamSpec("P")
T = typing.TypeVar("T")
async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501
async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg]
warnings.warn(
"run_until_first_complete is deprecated "
"and will be removed in a future version.",
"run_until_first_complete is deprecated and will be removed in a future version.",
DeprecationWarning,
)
async with anyio.create_task_group() as task_group:
async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] # noqa: E501
async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg]
await func()
task_group.cancel_scope.cancel()
@@ -33,13 +32,9 @@ async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None:
task_group.start_soon(run, functools.partial(func, **kwargs))
async def run_in_threadpool(
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
if kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
return await anyio.to_thread.run_sync(func, *args)
async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func = functools.partial(func, *args, **kwargs)
return await anyio.to_thread.run_sync(func)
class _StopIteration(Exception):

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import os
import typing
import warnings
from pathlib import Path
@@ -16,7 +17,7 @@ class EnvironError(Exception):
class Environ(typing.MutableMapping[str, str]):
def __init__(self, environ: typing.MutableMapping[str, str] = os.environ):
self._environ = environ
self._has_been_read: typing.Set[str] = set()
self._has_been_read: set[str] = set()
def __getitem__(self, key: str) -> str:
self._has_been_read.add(key)
@@ -24,18 +25,12 @@ class Environ(typing.MutableMapping[str, str]):
def __setitem__(self, key: str, value: str) -> None:
if key in self._has_been_read:
raise EnvironError(
f"Attempting to set environ['{key}'], but the value has already been "
"read."
)
raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
self._environ.__setitem__(key, value)
def __delitem__(self, key: str) -> None:
if key in self._has_been_read:
raise EnvironError(
f"Attempting to delete environ['{key}'], but the value has already "
"been read."
)
raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
self._environ.__delitem__(key)
def __iter__(self) -> typing.Iterator[str]:
@@ -59,23 +54,21 @@ class Config:
) -> None:
self.environ = environ
self.env_prefix = env_prefix
self.file_values: typing.Dict[str, str] = {}
self.file_values: dict[str, str] = {}
if env_file is not None:
if not os.path.isfile(env_file):
raise FileNotFoundError(f"Config file '{env_file}' not found.")
self.file_values = self._read_file(env_file)
warnings.warn(f"Config file '{env_file}' not found.")
else:
self.file_values = self._read_file(env_file)
@typing.overload
def __call__(self, key: str, *, default: None) -> str | None:
...
def __call__(self, key: str, *, default: None) -> str | None: ...
@typing.overload
def __call__(self, key: str, cast: type[T], default: T = ...) -> T:
...
def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ...
@typing.overload
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str:
...
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ...
@typing.overload
def __call__(
@@ -83,12 +76,10 @@ class Config:
key: str,
cast: typing.Callable[[typing.Any], T] = ...,
default: typing.Any = ...,
) -> T:
...
) -> T: ...
@typing.overload
def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str:
...
def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...
def __call__(
self,
@@ -116,7 +107,7 @@ class Config:
raise KeyError(f"Config '{key}' is missing, and has no default.")
def _read_file(self, file_name: str | Path) -> dict[str, str]:
file_values: typing.Dict[str, str] = {}
file_values: dict[str, str] = {}
with open(file_name) as input_file:
for line in input_file.readlines():
line = line.strip()
@@ -139,13 +130,9 @@ class Config:
mapping = {"true": True, "1": True, "false": False, "0": False}
value = value.lower()
if value not in mapping:
raise ValueError(
f"Config '{key}' has value '{value}'. Not a valid bool."
)
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
return mapping[value]
try:
return cast(value)
except (TypeError, ValueError):
raise ValueError(
f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}."
)
raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.")

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import math
import typing
import uuid
@@ -65,7 +67,7 @@ class FloatConvertor(Convertor[float]):
class UUIDConvertor(Convertor[uuid.UUID]):
regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
regex = "[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}"
def convert(self, value: str) -> uuid.UUID:
return uuid.UUID(value)
@@ -74,7 +76,7 @@ class UUIDConvertor(Convertor[uuid.UUID]):
return str(value)
CONVERTOR_TYPES: typing.Dict[str, Convertor[typing.Any]] = {
CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = {
"str": StringConvertor(),
"path": PathConvertor(),
"int": IntegerConvertor(),

View File

@@ -108,12 +108,7 @@ class URL:
return self.scheme in ("https", "wss")
def replace(self, **kwargs: typing.Any) -> URL:
if (
"username" in kwargs
or "password" in kwargs
or "hostname" in kwargs
or "port" in kwargs
):
if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
hostname = kwargs.pop("hostname", None)
port = kwargs.pop("port", self.port)
username = kwargs.pop("username", self.username)
@@ -150,7 +145,7 @@ class URL:
query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
return self.replace(query=query)
def remove_query_params(self, keys: str | typing.Sequence[str]) -> "URL":
def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
if isinstance(keys, str):
keys = [keys]
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
@@ -178,7 +173,7 @@ class URLPath(str):
Used by the routing to return `url_path_for` matches.
"""
def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
assert protocol in ("http", "websocket", "")
return str.__new__(cls, path)
@@ -251,30 +246,25 @@ class CommaSeparatedStrings(typing.Sequence[str]):
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
_dict: typing.Dict[_KeyType, _CovariantValueType]
_dict: dict[_KeyType, _CovariantValueType]
def __init__(
self,
*args: ImmutableMultiDict[_KeyType, _CovariantValueType]
| typing.Mapping[_KeyType, _CovariantValueType]
| typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]],
| typing.Iterable[tuple[_KeyType, _CovariantValueType]],
**kwargs: typing.Any,
) -> None:
assert len(args) < 2, "Too many arguments."
value: typing.Any = args[0] if args else []
if kwargs:
value = (
ImmutableMultiDict(value).multi_items()
+ ImmutableMultiDict(kwargs).multi_items()
)
value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
if not value:
_items: list[tuple[typing.Any, typing.Any]] = []
elif hasattr(value, "multi_items"):
value = typing.cast(
ImmutableMultiDict[_KeyType, _CovariantValueType], value
)
value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
_items = list(value.multi_items())
elif hasattr(value, "items"):
value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
@@ -371,9 +361,7 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
def update(
self,
*args: MultiDict
| typing.Mapping[typing.Any, typing.Any]
| list[tuple[typing.Any, typing.Any]],
*args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
**kwargs: typing.Any,
) -> None:
value = MultiDict(*args, **kwargs)
@@ -403,9 +391,7 @@ class QueryParams(ImmutableMultiDict[str, str]):
if isinstance(value, str):
super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
elif isinstance(value, bytes):
super().__init__(
parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
)
super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
else:
super().__init__(*args, **kwargs) # type: ignore[arg-type]
self._list = [(str(k), str(v)) for k, v in self._list]
@@ -490,9 +476,7 @@ class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
def __init__(
self,
*args: FormData
| typing.Mapping[str, str | UploadFile]
| list[tuple[str, str | UploadFile]],
*args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
**kwargs: str | UploadFile,
) -> None:
super().__init__(*args, **kwargs)
@@ -518,10 +502,7 @@ class Headers(typing.Mapping[str, str]):
if headers is not None:
assert raw is None, 'Cannot set both "headers" and "raw".'
assert scope is None, 'Cannot set both "headers" and "scope".'
self._list = [
(key.lower().encode("latin-1"), value.encode("latin-1"))
for key, value in headers.items()
]
self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
elif raw is not None:
assert scope is None, 'Cannot set both "raw" and "scope".'
self._list = raw
@@ -541,18 +522,11 @@ class Headers(typing.Mapping[str, str]):
return [value.decode("latin-1") for key, value in self._list]
def items(self) -> list[tuple[str, str]]: # type: ignore[override]
return [
(key.decode("latin-1"), value.decode("latin-1"))
for key, value in self._list
]
return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
def getlist(self, key: str) -> list[str]:
get_header_key = key.lower().encode("latin-1")
return [
item_value.decode("latin-1")
for item_key, item_value in self._list
if item_key == get_header_key
]
return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
def mutablecopy(self) -> MutableHeaders:
return MutableHeaders(raw=self._list[:])
@@ -599,7 +573,7 @@ class MutableHeaders(Headers):
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
found_indexes: "typing.List[int]" = []
found_indexes: list[int] = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
found_indexes.append(idx)
@@ -619,7 +593,7 @@ class MutableHeaders(Headers):
"""
del_key = key.lower().encode("latin-1")
pop_indexes: "typing.List[int]" = []
pop_indexes: list[int] = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)

View File

@@ -30,15 +30,9 @@ class HTTPEndpoint:
async def dispatch(self) -> None:
request = Request(self.scope, receive=self.receive)
handler_name = (
"get"
if request.method == "HEAD" and not hasattr(self, "head")
else request.method.lower()
)
handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
handler: typing.Callable[[Request], typing.Any] = getattr(
self, handler_name, self.method_not_allowed
)
handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed)
is_async = is_async_callable(handler)
if is_async:
response = await handler(request)
@@ -80,10 +74,8 @@ class WebSocketEndpoint:
if message["type"] == "websocket.receive":
data = await self.decode(websocket, message)
await self.on_receive(websocket, data)
elif message["type"] == "websocket.disconnect":
close_code = int(
message.get("code") or status.WS_1000_NORMAL_CLOSURE
)
elif message["type"] == "websocket.disconnect": # pragma: no branch
close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
break
except Exception as exc:
close_code = status.WS_1011_INTERNAL_ERROR
@@ -116,9 +108,7 @@ class WebSocketEndpoint:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
raise RuntimeError("Malformed JSON data received.")
assert (
self.encoding is None
), f"Unsupported 'encoding' attribute {self.encoding}"
assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
return message["text"] if message.get("text") else message["bytes"]
async def on_connect(self, websocket: WebSocket) -> None:

View File

@@ -1,19 +1,11 @@
from __future__ import annotations
import http
import typing
import warnings
__all__ = ("HTTPException", "WebSocketException")
from collections.abc import Mapping
class HTTPException(Exception):
def __init__(
self,
status_code: int,
detail: str | None = None,
headers: dict[str, str] | None = None,
) -> None:
def __init__(self, status_code: int, detail: str | None = None, headers: Mapping[str, str] | None = None) -> None:
if detail is None:
detail = http.HTTPStatus(status_code).phrase
self.status_code = status_code
@@ -39,24 +31,3 @@ class WebSocketException(Exception):
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(code={self.code!r}, reason={self.reason!r})"
__deprecated__ = "ExceptionMiddleware"
def __getattr__(name: str) -> typing.Any: # pragma: no cover
if name == __deprecated__:
from starlette.middleware.exceptions import ExceptionMiddleware
warnings.warn(
f"{__deprecated__} is deprecated on `starlette.exceptions`. "
f"Import it from `starlette.middleware.exceptions` instead.",
category=DeprecationWarning,
stacklevel=3,
)
return ExceptionMiddleware
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
def __dir__() -> list[str]:
return sorted(list(__all__) + [__deprecated__]) # pragma: no cover

View File

@@ -8,12 +8,20 @@ from urllib.parse import unquote_plus
from starlette.datastructures import FormData, Headers, UploadFile
try:
import multipart
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: nocover
parse_options_header = None
multipart = None
if typing.TYPE_CHECKING:
import python_multipart as multipart
from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
else:
try:
try:
import python_multipart as multipart
from python_multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
import multipart
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
multipart = None
parse_options_header = None
class FormMessage(Enum):
@@ -28,12 +36,12 @@ class FormMessage(Enum):
class MultipartPart:
content_disposition: bytes | None = None
field_name: str = ""
data: bytes = b""
data: bytearray = field(default_factory=bytearray)
file: UploadFile | None = None
item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
def _user_safe_decode(src: bytes, codec: str) -> str:
def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
try:
return src.decode(codec)
except (UnicodeDecodeError, LookupError):
@@ -46,12 +54,8 @@ class MultiPartException(Exception):
class FormParser:
def __init__(
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
) -> None:
assert (
multipart is not None
), "The `python-multipart` library must be installed to use form parsing."
def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.messages: list[tuple[FormMessage, bytes]] = []
@@ -78,7 +82,7 @@ class FormParser:
async def parse(self) -> FormData:
# Callbacks dictionary.
callbacks = {
callbacks: QuerystringCallbacks = {
"on_field_start": self.on_field_start,
"on_field_name": self.on_field_name,
"on_field_data": self.on_field_data,
@@ -91,7 +95,7 @@ class FormParser:
field_name = b""
field_value = b""
items: list[tuple[str, typing.Union[str, UploadFile]]] = []
items: list[tuple[str, str | UploadFile]] = []
# Feed the parser with data from the request.
async for chunk in self.stream:
@@ -118,7 +122,7 @@ class FormParser:
class MultiPartParser:
max_file_size = 1024 * 1024
max_file_size = 1024 * 1024 # 1MB
def __init__(
self,
@@ -127,10 +131,9 @@ class MultiPartParser:
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024, # 1MB
) -> None:
assert (
multipart is not None
), "The `python-multipart` library must be installed to use form parsing."
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
@@ -145,6 +148,7 @@ class MultiPartParser:
self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: list[MultipartPart] = []
self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
self.max_part_size = max_part_size
def on_part_begin(self) -> None:
self._current_part = MultipartPart()
@@ -152,7 +156,9 @@ class MultiPartParser:
def on_part_data(self, data: bytes, start: int, end: int) -> None:
message_bytes = data[start:end]
if self._current_part.file is None:
self._current_part.data += message_bytes
if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
self._current_part.data.extend(message_bytes)
else:
self._file_parts_to_write.append((self._current_part, message_bytes))
@@ -181,30 +187,20 @@ class MultiPartParser:
field = self._current_partial_header_name.lower()
if field == b"content-disposition":
self._current_part.content_disposition = self._current_partial_header_value
self._current_part.item_headers.append(
(field, self._current_partial_header_value)
)
self._current_part.item_headers.append((field, self._current_partial_header_value))
self._current_partial_header_name = b""
self._current_partial_header_value = b""
def on_headers_finished(self) -> None:
disposition, options = parse_options_header(
self._current_part.content_disposition
)
disposition, options = parse_options_header(self._current_part.content_disposition)
try:
self._current_part.field_name = _user_safe_decode(
options[b"name"], self._charset
)
self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
except KeyError:
raise MultiPartException(
'The Content-Disposition header field "name" must be ' "provided."
)
raise MultiPartException('The Content-Disposition header field "name" must be provided.')
if b"filename" in options:
self._current_files += 1
if self._current_files > self.max_files:
raise MultiPartException(
f"Too many files. Maximum number of files is {self.max_files}."
)
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
self._files_to_close_on_error.append(tempfile)
@@ -217,9 +213,7 @@ class MultiPartParser:
else:
self._current_fields += 1
if self._current_fields > self.max_fields:
raise MultiPartException(
f"Too many fields. Maximum number of fields is {self.max_fields}."
)
raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
self._current_part.file = None
def on_end(self) -> None:
@@ -238,7 +232,7 @@ class MultiPartParser:
raise MultiPartException("Missing boundary in multipart.")
# Callbacks dictionary.
callbacks = {
callbacks: MultipartCallbacks = {
"on_part_begin": self.on_part_begin,
"on_part_data": self.on_part_data,
"on_part_end": self.on_part_end,

View File

@@ -1,30 +1,27 @@
from __future__ import annotations
import sys
from typing import Any, Iterator, Protocol
from collections.abc import Iterator
from typing import Any, Protocol
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp
P = ParamSpec("P")
class _MiddlewareClass(Protocol[P]):
def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None:
... # pragma: no cover
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
... # pragma: no cover
class _MiddlewareFactory(Protocol[P]):
def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover
class Middleware:
def __init__(
self,
cls: type[_MiddlewareClass[P]],
cls: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
@@ -40,5 +37,6 @@ class Middleware:
class_name = self.__class__.__name__
args_strings = [f"{value!r}" for value in self.args]
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
name = getattr(self.cls, "__name__", "")
args_repr = ", ".join([name] + args_strings + option_strings)
return f"{class_name}({args_repr})"

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import typing
from starlette.authentication import (
@@ -16,15 +18,13 @@ class AuthenticationMiddleware:
self,
app: ASGIApp,
backend: AuthenticationBackend,
on_error: typing.Optional[
typing.Callable[[HTTPConnection, AuthenticationError], Response]
] = None,
on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
) -> None:
self.app = app
self.backend = backend
self.on_error: typing.Callable[
[HTTPConnection, AuthenticationError], Response
] = on_error if on_error is not None else self.default_on_error
self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
on_error if on_error is not None else self.default_on_error
)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ["http", "websocket"]:

View File

@@ -1,18 +1,16 @@
from __future__ import annotations
import typing
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
T = typing.TypeVar("T")
@@ -54,6 +52,7 @@ class _CachedRequest(Request):
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
self._wrapped_rcv_disconnected = True
return msg
# wrapped_rcv state 3: not yet consumed
@@ -92,9 +91,7 @@ class _CachedRequest(Request):
class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
) -> None:
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch is None else dispatch
@@ -108,10 +105,7 @@ class BaseHTTPMiddleware:
response_sent = anyio.Event()
async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send_stream, recv_stream = anyio.create_memory_object_stream()
app_exc: Exception | None = None
async def receive_or_disconnect() -> Message:
if response_sent.is_set():
@@ -132,10 +126,6 @@ class BaseHTTPMiddleware:
return message
async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()
async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
@@ -146,13 +136,12 @@ class BaseHTTPMiddleware:
async def coro() -> None:
nonlocal app_exc
async with send_stream:
with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc
task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)
try:
@@ -168,50 +157,65 @@ class BaseHTTPMiddleware:
assert message["type"] == "http.response.start"
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async with recv_stream:
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break
if app_exc is not None:
raise app_exc
response = _StreamingResponse(
status_code=message["status"], content=body_stream(), info=info
)
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
response.raw_headers = message["headers"]
return response
with collapse_excgroups():
streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
send_stream, recv_stream = streams
with recv_stream, send_stream, collapse_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
recv_stream.close()
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover
class _StreamingResponse(StreamingResponse):
class _StreamingResponse(Response):
def __init__(
self,
content: ContentStream,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
info: typing.Mapping[str, typing.Any] | None = None,
) -> None:
self._info = info
super().__init__(content, status_code, headers, media_type, background)
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)
self.background = None
async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
async for chunk in self.body_iterator:
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"", "more_body": False})
if self.background:
await self.background()

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import functools
import re
import typing
@@ -18,7 +20,7 @@ class CORSMiddleware:
allow_methods: typing.Sequence[str] = ("GET",),
allow_headers: typing.Sequence[str] = (),
allow_credentials: bool = False,
allow_origin_regex: typing.Optional[str] = None,
allow_origin_regex: str | None = None,
expose_headers: typing.Sequence[str] = (),
max_age: int = 600,
) -> None:
@@ -94,9 +96,7 @@ class CORSMiddleware:
if self.allow_all_origins:
return True
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
origin
):
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
return True
return origin in self.allow_origins
@@ -139,15 +139,11 @@ class CORSMiddleware:
return PlainTextResponse("OK", status_code=200, headers=headers)
async def simple_response(
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
) -> None:
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
send = functools.partial(self.send, send=send, request_headers=request_headers)
await self.app(scope, receive, send)
async def send(
self, message: Message, send: Send, request_headers: Headers
) -> None:
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
if message["type"] != "http.response.start":
await send(message)
return

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
import html
import inspect
import sys
import traceback
import typing
@@ -137,9 +140,7 @@ class ServerErrorMiddleware:
def __init__(
self,
app: ASGIApp,
handler: typing.Optional[
typing.Callable[[Request, Exception], typing.Any]
] = None,
handler: typing.Callable[[Request, Exception], typing.Any] | None = None,
debug: bool = False,
) -> None:
self.app = app
@@ -185,9 +186,7 @@ class ServerErrorMiddleware:
# to optionally raise the error within the test case.
raise exc
def format_line(
self, index: int, line: str, frame_lineno: int, frame_index: int
) -> str:
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
values = {
# HTML escape - line could contain < or >
"line": html.escape(line).replace(" ", "&nbsp"),
@@ -224,9 +223,7 @@ class ServerErrorMiddleware:
return FRAME_TEMPLATE.format(**values)
def generate_html(self, exc: Exception, limit: int = 7) -> str:
traceback_obj = traceback.TracebackException.from_exception(
exc, capture_locals=True
)
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
exc_html = ""
is_collapsed = False
@@ -237,11 +234,13 @@ class ServerErrorMiddleware:
exc_html += self.generate_frame_html(frame, is_collapsed)
is_collapsed = True
if sys.version_info >= (3, 13): # pragma: no cover
exc_type_str = traceback_obj.exc_type_str
else: # pragma: no cover
exc_type_str = traceback_obj.exc_type.__name__
# escape error class and text
error = (
f"{html.escape(traceback_obj.exc_type.__name__)}: "
f"{html.escape(str(traceback_obj))}"
)
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import typing
from starlette._exception_handler import (
@@ -16,9 +18,7 @@ class ExceptionMiddleware:
def __init__(
self,
app: ASGIApp,
handlers: typing.Optional[
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
] = None,
handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
debug: bool = False,
) -> None:
self.app = app
@@ -28,13 +28,13 @@ class ExceptionMiddleware:
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
}
if handlers is not None:
if handlers is not None: # pragma: no branch
for key, value in handlers.items():
self.add_exception_handler(key, value)
def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
exc_class_or_status_code: int | type[Exception],
handler: typing.Callable[[Request, Exception], Response],
) -> None:
if isinstance(exc_class_or_status_code, int):
@@ -53,7 +53,7 @@ class ExceptionMiddleware:
self._status_handlers,
)
conn: typing.Union[Request, WebSocket]
conn: Request | WebSocket
if scope["type"] == "http":
conn = Request(scope, receive, send)
else:
@@ -65,9 +65,7 @@ class ExceptionMiddleware:
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)

View File

@@ -7,20 +7,16 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send
class GZipMiddleware:
def __init__(
self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
) -> None:
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
self.app = app
self.minimum_size = minimum_size
self.compresslevel = compresslevel
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
if scope["type"] == "http": # pragma: no branch
headers = Headers(scope=scope)
if "gzip" in headers.get("Accept-Encoding", ""):
responder = GZipResponder(
self.app, self.minimum_size, compresslevel=self.compresslevel
)
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
await responder(scope, receive, send)
return
await self.app(scope, receive, send)
@@ -35,13 +31,12 @@ class GZipResponder:
self.started = False
self.content_encoding_set = False
self.gzip_buffer = io.BytesIO()
self.gzip_file = gzip.GzipFile(
mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
)
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.send = send
await self.app(scope, receive, self.send_with_gzip)
with self.gzip_buffer, self.gzip_file:
await self.app(scope, receive, self.send_with_gzip)
async def send_with_gzip(self, message: Message) -> None:
message_type = message["type"]
@@ -93,7 +88,7 @@ class GZipResponder:
await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body":
elif message_type == "http.response.body": # pragma: no branch
# Remaining body in streaming GZip response.
body = message.get("body", b"")
more_body = message.get("more_body", False)

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import json
import typing
from base64 import b64decode, b64encode
@@ -14,13 +16,13 @@ class SessionMiddleware:
def __init__(
self,
app: ASGIApp,
secret_key: typing.Union[str, Secret],
secret_key: str | Secret,
session_cookie: str = "session",
max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
path: str = "/",
same_site: typing.Literal["lax", "strict", "none"] = "lax",
https_only: bool = False,
domain: typing.Optional[str] = None,
domain: str | None = None,
) -> None:
self.app = app
self.signer = itsdangerous.TimestampSigner(str(secret_key))
@@ -59,7 +61,7 @@ class SessionMiddleware:
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
@@ -70,7 +72,7 @@ class SessionMiddleware:
elif not initial_session_was_empty:
# The session has been cleared.
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
session_cookie=self.session_cookie,
data="null",
path=self.path,

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import typing
from starlette.datastructures import URL, Headers
@@ -11,7 +13,7 @@ class TrustedHostMiddleware:
def __init__(
self,
app: ASGIApp,
allowed_hosts: typing.Optional[typing.Sequence[str]] = None,
allowed_hosts: typing.Sequence[str] | None = None,
www_redirect: bool = True,
) -> None:
if allowed_hosts is None:
@@ -39,9 +41,7 @@ class TrustedHostMiddleware:
is_valid_host = False
found_www_redirect = False
for pattern in self.allowed_hosts:
if host == pattern or (
pattern.startswith("*") and host.endswith(pattern[1:])
):
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
is_valid_host = True
break
elif "www." + host == pattern:

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import io
import math
import sys
@@ -16,7 +18,7 @@ warnings.warn(
)
def build_environ(scope: Scope, body: bytes) -> typing.Dict[str, typing.Any]:
def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]:
"""
Builds a scope and request body into a WSGI environ object.
"""
@@ -87,9 +89,7 @@ class WSGIResponder:
self.scope = scope
self.status = None
self.response_headers = None
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
math.inf
)
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
self.response_started = False
self.exc_info: typing.Any = None
@@ -117,11 +117,11 @@ class WSGIResponder:
def start_response(
self,
status: str,
response_headers: typing.List[typing.Tuple[str, str]],
response_headers: list[tuple[str, str]],
exc_info: typing.Any = None,
) -> None:
self.exc_info = exc_info
if not self.response_started:
if not self.response_started: # pragma: no branch
self.response_started = True
status_code_string, _ = status.split(" ", 1)
status_code = int(status_code_string)
@@ -140,7 +140,7 @@ class WSGIResponder:
def wsgi(
self,
environ: typing.Dict[str, typing.Any],
environ: dict[str, typing.Any],
start_response: typing.Callable[..., typing.Any],
) -> None:
for chunk in self.app(environ, start_response):
@@ -149,6 +149,4 @@ class WSGIResponder:
{"type": "http.response.body", "body": chunk, "more_body": True},
)
anyio.from_thread.run(
self.stream_send.send, {"type": "http.response.body", "body": b""}
)
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})

View File

@@ -12,14 +12,19 @@ from starlette.exceptions import HTTPException
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
from starlette.types import Message, Receive, Scope, Send
try:
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: nocover
parse_options_header = None
if typing.TYPE_CHECKING:
from python_multipart.multipart import parse_options_header
from starlette.applications import Starlette
from starlette.routing import Router
else:
try:
try:
from python_multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
parse_options_header = None
SERVER_PUSH_HEADERS_TO_COPY = {
@@ -43,7 +48,7 @@ def cookie_parser(cookie_string: str) -> dict[str, str]:
Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
on an outdated spec and will fail on lots of input we want to support
"""
cookie_dict: typing.Dict[str, str] = {}
cookie_dict: dict[str, str] = {}
for chunk in cookie_string.split(";"):
if "=" in chunk:
key, val = chunk.split("=", 1)
@@ -93,7 +98,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
@property
def url(self) -> URL:
if not hasattr(self, "_url"):
if not hasattr(self, "_url"): # pragma: no branch
self._url = URL(scope=self.scope)
return self._url
@@ -104,9 +109,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
# This is used by request.url_for, it might be used inside a Mount which
# would have its own child scope with its own root_path, but the base URL
# for url_for should still be the top level app root path.
app_root_path = base_url_scope.get(
"app_root_path", base_url_scope.get("root_path", "")
)
app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
path = app_root_path
if not path.endswith("/"):
path += "/"
@@ -124,7 +127,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
@property
def query_params(self) -> QueryParams:
if not hasattr(self, "_query_params"):
if not hasattr(self, "_query_params"): # pragma: no branch
self._query_params = QueryParams(self.scope["query_string"])
return self._query_params
@@ -135,7 +138,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
@property
def cookies(self) -> dict[str, str]:
if not hasattr(self, "_cookies"):
cookies: typing.Dict[str, str] = {}
cookies: dict[str, str] = {}
cookie_header = self.headers.get("cookie")
if cookie_header:
@@ -145,7 +148,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
@property
def client(self) -> Address | None:
# client is a 2 item tuple of (host, port), None or missing
# client is a 2 item tuple of (host, port), None if missing
host_port = self.scope.get("client")
if host_port is not None:
return Address(*host_port)
@@ -153,23 +156,17 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
@property
def session(self) -> dict[str, typing.Any]:
assert (
"session" in self.scope
), "SessionMiddleware must be installed to access request.session"
assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
return self.scope["session"] # type: ignore[no-any-return]
@property
def auth(self) -> typing.Any:
assert (
"auth" in self.scope
), "AuthenticationMiddleware must be installed to access request.auth"
assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
return self.scope["auth"]
@property
def user(self) -> typing.Any:
assert (
"user" in self.scope
), "AuthenticationMiddleware must be installed to access request.user"
assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
return self.scope["user"]
@property
@@ -183,8 +180,10 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
return self._state
def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
router: Router = self.scope["router"]
url_path = router.url_path_for(name, **path_params)
url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
if url_path_provider is None:
raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
url_path = url_path_provider.url_path_for(name, **path_params)
return url_path.make_absolute_url(base_url=self.base_url)
@@ -197,11 +196,9 @@ async def empty_send(message: Message) -> typing.NoReturn:
class Request(HTTPConnection):
_form: typing.Optional[FormData]
_form: FormData | None
def __init__(
self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
):
def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
super().__init__(scope)
assert scope["type"] == "http"
self._receive = receive
@@ -233,29 +230,33 @@ class Request(HTTPConnection):
self._stream_consumed = True
if body:
yield body
elif message["type"] == "http.disconnect":
elif message["type"] == "http.disconnect": # pragma: no branch
self._is_disconnected = True
raise ClientDisconnect()
yield b""
async def body(self) -> bytes:
if not hasattr(self, "_body"):
chunks: "typing.List[bytes]" = []
chunks: list[bytes] = []
async for chunk in self.stream():
chunks.append(chunk)
self._body = b"".join(chunks)
return self._body
async def json(self) -> typing.Any:
if not hasattr(self, "_json"):
if not hasattr(self, "_json"): # pragma: no branch
body = await self.body()
self._json = json.loads(body)
return self._json
async def _get_form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> FormData:
if self._form is None:
if self._form is None: # pragma: no branch
assert (
parse_options_header is not None
), "The `python-multipart` library must be installed to use form parsing."
@@ -269,6 +270,7 @@ class Request(HTTPConnection):
self.stream(),
max_files=max_files,
max_fields=max_fields,
max_part_size=max_part_size,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
@@ -283,14 +285,18 @@ class Request(HTTPConnection):
return self._form
def form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(
self._get_form(max_files=max_files, max_fields=max_fields)
self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
)
async def close(self) -> None:
if self._form is not None:
if self._form is not None: # pragma: no branch
await self._form.close()
async def is_disconnected(self) -> bool:
@@ -309,12 +315,8 @@ class Request(HTTPConnection):
async def send_push_promise(self, path: str) -> None:
if "http.response.push" in self.scope.get("extensions", {}):
raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = []
raw_headers: list[tuple[bytes, bytes]] = []
for name in SERVER_PUSH_HEADERS_TO_COPY:
for value in self.headers.getlist(name):
raw_headers.append(
(name.encode("latin-1"), value.encode("latin-1"))
)
await self._send(
{"type": "http.response.push", "path": path, "headers": raw_headers}
)
raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import hashlib
import http.cookies
import json
import os
import re
import stat
import typing
import warnings
@@ -10,15 +12,16 @@ from datetime import datetime
from email.utils import format_datetime, formatdate
from functools import partial
from mimetypes import guess_type
from secrets import token_hex
from urllib.parse import quote
import anyio
import anyio.to_thread
from starlette._compat import md5_hexdigest
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, MutableHeaders
from starlette.datastructures import URL, Headers, MutableHeaders
from starlette.requests import ClientDisconnect
from starlette.types import Receive, Scope, Send
@@ -41,10 +44,10 @@ class Response:
self.body = self.render(content)
self.init_headers(headers)
def render(self, content: typing.Any) -> bytes:
def render(self, content: typing.Any) -> bytes | memoryview:
if content is None:
return b""
if isinstance(content, bytes):
if isinstance(content, (bytes, memoryview)):
return content
return content.encode(self.charset) # type: ignore
@@ -54,10 +57,7 @@ class Response:
populate_content_length = True
populate_content_type = True
else:
raw_headers = [
(k.lower().encode("latin-1"), v.encode("latin-1"))
for k, v in headers.items()
]
raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
keys = [h[0] for h in raw_headers]
populate_content_length = b"content-length" not in keys
populate_content_type = b"content-type" not in keys
@@ -73,10 +73,7 @@ class Response:
content_type = self.media_type
if content_type is not None and populate_content_type:
if (
content_type.startswith("text/")
and "charset=" not in content_type.lower()
):
if content_type.startswith("text/") and "charset=" not in content_type.lower():
content_type += "; charset=" + self.charset
raw_headers.append((b"content-type", content_type.encode("latin-1")))
@@ -94,7 +91,7 @@ class Response:
value: str = "",
max_age: int | None = None,
expires: datetime | str | int | None = None,
path: str = "/",
path: str | None = "/",
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
@@ -148,14 +145,15 @@ class Response:
)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
prefix = "websocket." if scope["type"] == "websocket" else ""
await send(
{
"type": "http.response.start",
"type": prefix + "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
await send({"type": "http.response.body", "body": self.body})
await send({"type": prefix + "http.response.body", "body": self.body})
if self.background is not None:
await self.background()
@@ -200,13 +198,11 @@ class RedirectResponse(Response):
headers: typing.Mapping[str, str] | None = None,
background: BackgroundTask | None = None,
) -> None:
super().__init__(
content=b"", status_code=status_code, headers=headers, background=background
)
super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
Content = typing.Union[str, bytes]
Content = typing.Union[str, bytes, memoryview]
SyncContentStream = typing.Iterable[Content]
AsyncContentStream = typing.AsyncIterable[Content]
ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
@@ -247,26 +243,47 @@ class StreamingResponse(Response):
}
)
async for chunk in self.body_iterator:
if not isinstance(chunk, bytes):
if not isinstance(chunk, (bytes, memoryview)):
chunk = chunk.encode(self.charset)
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"", "more_body": False})
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with anyio.create_task_group() as task_group:
spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()
if spec_version >= (2, 4):
try:
await self.stream_response(send)
except OSError:
raise ClientDisconnect()
else:
async with anyio.create_task_group() as task_group:
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
if self.background is not None:
await self.background()
class MalformedRangeHeader(Exception):
def __init__(self, content: str = "Malformed range header.") -> None:
self.content = content
class RangeNotSatisfiable(Exception):
def __init__(self, max_size: int) -> None:
self.max_size = max_size
_RANGE_PATTERN = re.compile(r"(\d*)-(\d*)")
class FileResponse(Response):
chunk_size = 64 * 1024
@@ -295,16 +312,13 @@ class FileResponse(Response):
self.media_type = media_type
self.background = background
self.init_headers(headers)
self.headers.setdefault("accept-ranges", "bytes")
if self.filename is not None:
content_disposition_filename = quote(self.filename)
if content_disposition_filename != self.filename:
content_disposition = "{}; filename*=utf-8''{}".format(
content_disposition_type, content_disposition_filename
)
content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
else:
content_disposition = '{}; filename="{}"'.format(
content_disposition_type, self.filename
)
content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
self.headers.setdefault("content-disposition", content_disposition)
self.stat_result = stat_result
if stat_result is not None:
@@ -314,13 +328,14 @@ class FileResponse(Response):
content_length = str(stat_result.st_size)
last_modified = formatdate(stat_result.st_mtime, usegmt=True)
etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
etag = f'"{md5_hexdigest(etag_base.encode(), usedforsecurity=False)}"'
etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'
self.headers.setdefault("content-length", content_length)
self.headers.setdefault("last-modified", last_modified)
self.headers.setdefault("etag", etag)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
send_header_only: bool = scope["method"].upper() == "HEAD"
if self.stat_result is None:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
@@ -331,29 +346,192 @@ class FileResponse(Response):
mode = stat_result.st_mode
if not stat.S_ISREG(mode):
raise RuntimeError(f"File at path {self.path} is not a file.")
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
if scope["method"].upper() == "HEAD":
else:
stat_result = self.stat_result
headers = Headers(scope=scope)
http_range = headers.get("range")
http_if_range = headers.get("if-range")
if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
await self._handle_simple(send, send_header_only)
else:
try:
ranges = self._parse_range_header(http_range, stat_result.st_size)
except MalformedRangeHeader as exc:
return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
except RangeNotSatisfiable as exc:
response = PlainTextResponse(status_code=416, headers={"Content-Range": f"*/{exc.max_size}"})
return await response(scope, receive, send)
if len(ranges) == 1:
start, end = ranges[0]
await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
else:
await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)
if self.background is not None:
await self.background()
async def _handle_simple(self, send: Send, send_header_only: bool) -> None:
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
elif "extensions" in scope and "http.response.pathsend" in scope["extensions"]:
await send({"type": "http.response.pathsend", "path": str(self.path)})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
more_body = True
while more_body:
chunk = await file.read(self.chunk_size)
more_body = len(chunk) == self.chunk_size
await send(
{
"type": "http.response.body",
"body": chunk,
"more_body": more_body,
}
)
if self.background is not None:
await self.background()
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
async def _handle_single_range(
self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
) -> None:
self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
self.headers["content-length"] = str(end - start)
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
await file.seek(start)
more_body = True
while more_body:
chunk = await file.read(min(self.chunk_size, end - start))
start += len(chunk)
more_body = len(chunk) == self.chunk_size and start < end
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
async def _handle_multiple_ranges(
self,
send: Send,
ranges: list[tuple[int, int]],
file_size: int,
send_header_only: bool,
) -> None:
# In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes).
boundary = token_hex(13)
content_length, header_generator = self.generate_multipart(
ranges, boundary, file_size, self.headers["content-type"]
)
self.headers["content-range"] = f"multipart/byteranges; boundary={boundary}"
self.headers["content-length"] = str(content_length)
await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
for start, end in ranges:
await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
await file.seek(start)
while start < end:
chunk = await file.read(min(self.chunk_size, end - start))
start += len(chunk)
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"\n", "more_body": True})
await send(
{
"type": "http.response.body",
"body": f"\n--{boundary}--\n".encode("latin-1"),
"more_body": False,
}
)
def _should_use_range(self, http_if_range: str) -> bool:
return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]
@staticmethod
def _parse_range_header(http_range: str, file_size: int) -> list[tuple[int, int]]:
ranges: list[tuple[int, int]] = []
try:
units, range_ = http_range.split("=", 1)
except ValueError:
raise MalformedRangeHeader()
units = units.strip().lower()
if units != "bytes":
raise MalformedRangeHeader("Only support bytes range")
ranges = [
(
int(_[0]) if _[0] else file_size - int(_[1]),
int(_[1]) + 1 if _[0] and _[1] and int(_[1]) < file_size else file_size,
)
for _ in _RANGE_PATTERN.findall(range_)
if _ != ("", "")
]
if len(ranges) == 0:
raise MalformedRangeHeader("Range header: range must be requested")
if any(not (0 <= start < file_size) for start, _ in ranges):
raise RangeNotSatisfiable(file_size)
if any(start > end for start, end in ranges):
raise MalformedRangeHeader("Range header: start must be less than end")
if len(ranges) == 1:
return ranges
# Merge ranges
result: list[tuple[int, int]] = []
for start, end in ranges:
for p in range(len(result)):
p_start, p_end = result[p]
if start > p_end:
continue
elif end < p_start:
result.insert(p, (start, end)) # THIS IS NOT REACHED!
break
else:
result[p] = (min(start, p_start), max(end, p_end))
break
else:
result.append((start, end))
return result
def generate_multipart(
self,
ranges: typing.Sequence[tuple[int, int]],
boundary: str,
max_size: int,
content_type: str,
) -> tuple[int, typing.Callable[[int, int], bytes]]:
r"""
Multipart response headers generator.
```
--{boundary}\n
Content-Type: {content_type}\n
Content-Range: bytes {start}-{end-1}/{max_size}\n
\n
..........content...........\n
--{boundary}\n
Content-Type: {content_type}\n
Content-Range: bytes {start}-{end-1}/{max_size}\n
\n
..........content...........\n
--{boundary}--\n
```
"""
boundary_len = len(boundary)
static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size))
content_length = sum(
(len(str(start)) + len(str(end - 1)) + static_header_part_len) # Headers
+ (end - start) # Content
for start, end in ranges
) + (
5 + boundary_len # --boundary--\n
)
return (
content_length,
lambda start, end: (
f"--{boundary}\n"
f"Content-Type: {content_type}\n"
f"Content-Range: bytes {start}-{end-1}/{max_size}\n"
"\n"
).encode("latin-1"),
)

View File

@@ -47,8 +47,7 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
including those wrapped in functools.partial objects.
"""
warnings.warn(
"iscoroutinefunction_or_partial is deprecated, "
"and will be removed in a future release.",
"iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
DeprecationWarning,
)
while isinstance(obj, functools.partial):
@@ -57,23 +56,21 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
def request_response(
func: typing.Callable[
[Request], typing.Union[typing.Awaitable[Response], Response]
],
func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
f: typing.Callable[[Request], typing.Awaitable[Response]] = (
func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive, send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if is_async_callable(func):
response = await func(request)
else:
response = await run_in_threadpool(func, request)
response = await f(request)
await response(scope, receive, send)
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
@@ -101,9 +98,7 @@ def websocket_session(
def get_name(endpoint: typing.Callable[..., typing.Any]) -> str:
if inspect.isroutine(endpoint) or inspect.isclass(endpoint):
return endpoint.__name__
return endpoint.__class__.__name__
return getattr(endpoint, "__name__", endpoint.__class__.__name__)
def replace_params(
@@ -147,9 +142,7 @@ def compile_path(
for match in PARAM_REGEX.finditer(path):
param_name, convertor_type = match.groups("str")
convertor_type = convertor_type.lstrip(":")
assert (
convertor_type in CONVERTOR_TYPES
), f"Unknown path convertor '{convertor_type}'"
assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
convertor = CONVERTOR_TYPES[convertor_type]
path_regex += re.escape(path[idx : match.start()])
@@ -203,7 +196,7 @@ class BaseRoute:
if scope["type"] == "http":
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)
elif scope["type"] == "websocket":
elif scope["type"] == "websocket": # pragma: no branch
websocket_close = WebSocketClose()
await websocket_close(scope, receive, send)
return
@@ -243,7 +236,7 @@ class Route(BaseRoute):
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.app = cls(self.app, *args, **kwargs)
if methods is None:
self.methods = None
@@ -255,7 +248,7 @@ class Route(BaseRoute):
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
path_params: dict[str, typing.Any]
if scope["type"] == "http":
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
@@ -279,9 +272,7 @@ class Route(BaseRoute):
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
)
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="http")
@@ -291,9 +282,7 @@ class Route(BaseRoute):
if "app" in scope:
raise HTTPException(status_code=405, headers=headers)
else:
response = PlainTextResponse(
"Method Not Allowed", status_code=405, headers=headers
)
response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)
@@ -339,12 +328,12 @@ class WebSocketRoute(BaseRoute):
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.app = cls(self.app, *args, **kwargs)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
path_params: dict[str, typing.Any]
if scope["type"] == "websocket":
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
@@ -365,9 +354,7 @@ class WebSocketRoute(BaseRoute):
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
)
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="websocket")
@@ -375,11 +362,7 @@ class WebSocketRoute(BaseRoute):
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, WebSocketRoute)
and self.path == other.path
and self.endpoint == other.endpoint
)
return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
@@ -396,9 +379,7 @@ class Mount(BaseRoute):
middleware: typing.Sequence[Middleware] | None = None,
) -> None:
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
assert (
app is not None or routes is not None
), "Either 'app=...', or 'routes=' must be specified"
assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if app is not None:
self._base_app: ASGIApp = app
@@ -407,19 +388,17 @@ class Mount(BaseRoute):
self.app = self._base_app
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.app = cls(self.app, *args, **kwargs)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(
self.path + "/{path:path}"
)
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
@property
def routes(self) -> list[BaseRoute]:
return getattr(self._base_app, "routes", [])
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] in ("http", "websocket"):
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, typing.Any]
if scope["type"] in ("http", "websocket"): # pragma: no branch
root_path = scope.get("root_path", "")
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
@@ -454,9 +433,7 @@ class Mount(BaseRoute):
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path_params["path"] = path_params["path"].lstrip("/")
path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
)
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path)
elif self.name is None or name.startswith(self.name + ":"):
@@ -468,17 +445,13 @@ class Mount(BaseRoute):
remaining_name = name[len(self.name) + 1 :]
path_kwarg = path_params.get("path")
path_params["path"] = ""
path_prefix, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
)
path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if path_kwarg is not None:
remaining_params["path"] = path_kwarg
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
return URLPath(
path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
)
return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
except NoMatchFound:
pass
raise NoMatchFound(name, path_params)
@@ -487,11 +460,7 @@ class Mount(BaseRoute):
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, Mount)
and self.path == other.path
and self.app == other.app
)
return isinstance(other, Mount) and self.path == other.path and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
@@ -512,7 +481,7 @@ class Host(BaseRoute):
return getattr(self.app, "routes", [])
def matches(self, scope: Scope) -> tuple[Match, Scope]:
if scope["type"] in ("http", "websocket"):
if scope["type"] in ("http", "websocket"): # pragma:no branch
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
@@ -530,9 +499,7 @@ class Host(BaseRoute):
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path = path_params.pop("path")
host, remaining_params = replace_params(
self.host_format, self.param_convertors, path_params
)
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path, host=host)
elif self.name is None or name.startswith(self.name + ":"):
@@ -542,9 +509,7 @@ class Host(BaseRoute):
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
host, remaining_params = replace_params(
self.host_format, self.param_convertors, path_params
)
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
@@ -557,11 +522,7 @@ class Host(BaseRoute):
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, Host)
and self.host == other.host
and self.app == other.app
)
return isinstance(other, Host) and self.host == other.host and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
@@ -589,9 +550,7 @@ class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
def _wrap_gen_lifespan_context(
lifespan_context: typing.Callable[
[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]
],
lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]],
) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
cmgr = contextlib.contextmanager(lifespan_context)
@@ -734,9 +693,7 @@ class Router:
async with self.lifespan_context(app) as maybe_state:
if maybe_state is not None:
if "state" not in scope:
raise RuntimeError(
'The server does not support "state" in the lifespan scope.'
)
raise RuntimeError('The server does not support "state" in the lifespan scope.')
scope["state"].update(maybe_state)
await send({"type": "lifespan.startup.complete"})
started = True
@@ -810,15 +767,11 @@ class Router:
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, Router) and self.routes == other.routes
def mount(
self, path: str, app: ASGIApp, name: str | None = None
) -> None: # pragma: nocover
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
route = Mount(path, app=app, name=name)
self.routes.append(route)
def host(
self, host: str, app: ASGIApp, name: str | None = None
) -> None: # pragma: no cover
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
route = Host(host, app=app, name=name)
self.routes.append(route)
@@ -829,7 +782,7 @@ class Router:
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> None: # pragma: nocover
) -> None: # pragma: no cover
route = Route(
path,
endpoint=endpoint,
@@ -864,11 +817,11 @@ class Router:
"""
warnings.warn(
"The `route` decorator is deprecated, and will be removed in version 1.0.0."
"Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", # noqa: E501
"Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_route(
path,
func,
@@ -889,20 +842,18 @@ class Router:
>>> app = Starlette(routes=routes)
"""
warnings.warn(
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " # noqa: E501
"https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
"The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
"https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_websocket_route(path, func, name=name)
return func
return decorator
def add_event_handler(
self, event_type: str, func: typing.Callable[[], typing.Any]
) -> None: # pragma: no cover
def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover
assert event_type in ("startup", "shutdown")
if event_type == "startup":
@@ -912,12 +863,12 @@ class Router:
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
DeprecationWarning,
)
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_event_handler(event_type, func)
return func

View File

@@ -10,7 +10,7 @@ from starlette.routing import BaseRoute, Host, Mount, Route
try:
import yaml
except ModuleNotFoundError: # pragma: nocover
except ModuleNotFoundError: # pragma: no cover
yaml = None # type: ignore[assignment]
@@ -19,9 +19,7 @@ class OpenAPIResponse(Response):
def render(self, content: typing.Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
assert isinstance(
content, dict
), "The schema passed to OpenAPIResponse should be a dictionary."
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
return yaml.dump(content, default_flow_style=False).encode("utf-8")
@@ -31,6 +29,9 @@ class EndpointInfo(typing.NamedTuple):
func: typing.Callable[..., typing.Any]
_remove_converter_pattern = re.compile(r":\w+}")
class BaseSchemaGenerator:
def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
raise NotImplementedError() # pragma: no cover
@@ -73,9 +74,7 @@ class BaseSchemaGenerator:
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
endpoints_info.append(
EndpointInfo(path, method.lower(), route.endpoint)
)
endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
else:
path = self._remove_converter(route.path)
for method in ["get", "post", "put", "patch", "delete", "options"]:
@@ -93,11 +92,9 @@ class BaseSchemaGenerator:
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
Should be represented as `/users/{id}` in the OpenAPI schema.
"""
return re.sub(r":\w+}", "}", path)
return _remove_converter_pattern.sub("}", path)
def parse_docstring(
self, func_or_method: typing.Callable[..., typing.Any]
) -> dict[str, typing.Any]:
def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import errno
import importlib.util
import os
import stat
@@ -31,11 +32,7 @@ class NotModifiedResponse(Response):
def __init__(self, headers: Headers):
super().__init__(
status_code=304,
headers={
name: value
for name, value in headers.items()
if name in self.NOT_MODIFIED_HEADERS
},
headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
)
@@ -79,9 +76,7 @@ class StaticFiles:
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
package_directory = os.path.normpath(
os.path.join(spec.origin, "..", statics_dir)
)
package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
assert os.path.isdir(
package_directory
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
@@ -109,7 +104,7 @@ class StaticFiles:
with OS specific path separators, and any '..', '.' components removed.
"""
route_path = get_route_path(scope)
return os.path.normpath(os.path.join(*route_path.split("/"))) # noqa: E501
return os.path.normpath(os.path.join(*route_path.split("/")))
async def get_response(self, path: str, scope: Scope) -> Response:
"""
@@ -119,13 +114,15 @@ class StaticFiles:
raise HTTPException(status_code=405)
try:
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, path
)
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
except PermissionError:
raise HTTPException(status_code=401)
except OSError:
raise
except OSError as exc:
# Filename is too long, so it can't be a valid static file.
if exc.errno == errno.ENAMETOOLONG:
raise HTTPException(status_code=404)
raise exc
if stat_result and stat.S_ISREG(stat_result.st_mode):
# We have a static file to serve.
@@ -135,9 +132,7 @@ class StaticFiles:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = os.path.join(path, "index.html")
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, index_path
)
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
if not scope["path"].endswith("/"):
# Directory URLs should redirect to always end in "/".
@@ -148,9 +143,7 @@ class StaticFiles:
if self.html:
# Check for '404.html' if we're in HTML mode.
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, "404.html"
)
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
if stat_result and stat.S_ISREG(stat_result.st_mode):
return FileResponse(full_path, stat_result=stat_result, status_code=404)
raise HTTPException(status_code=404)
@@ -162,10 +155,9 @@ class StaticFiles:
full_path = os.path.abspath(joined_path)
else:
full_path = os.path.realpath(joined_path)
directory = os.path.realpath(directory)
if os.path.commonpath([full_path, directory]) != directory:
# Don't allow misbehaving clients to break out of the static files
# directory.
directory = os.path.realpath(directory)
if os.path.commonpath([full_path, directory]) != str(directory):
# Don't allow misbehaving clients to break out of the static files directory.
continue
try:
return full_path, os.stat(full_path)
@@ -182,9 +174,7 @@ class StaticFiles:
) -> Response:
request_headers = Headers(scope=scope)
response = FileResponse(
full_path, status_code=status_code, stat_result=stat_result
)
response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
if self.is_not_modified(response.headers, request_headers):
return NotModifiedResponse(response.headers)
return response
@@ -201,17 +191,11 @@ class StaticFiles:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
except FileNotFoundError:
raise RuntimeError(
f"StaticFiles directory '{self.directory}' does not exist."
)
raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
raise RuntimeError(
f"StaticFiles path '{self.directory}' is not a directory."
)
raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
def is_not_modified(
self, response_headers: Headers, request_headers: Headers
) -> bool:
def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
"""
Given the request and response headers, return `True` if an HTTP
"Not Modified" response could be returned instead.
@@ -227,11 +211,7 @@ class StaticFiles:
try:
if_modified_since = parsedate(request_headers["if-modified-since"])
last_modified = parsedate(response_headers["last-modified"])
if (
if_modified_since is not None
and last_modified is not None
and if_modified_since >= last_modified
):
if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
return True
except KeyError:
pass

View File

@@ -5,91 +5,9 @@ https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
And RFC 2324 - https://tools.ietf.org/html/rfc2324
"""
from __future__ import annotations
import warnings
__all__ = (
"HTTP_100_CONTINUE",
"HTTP_101_SWITCHING_PROTOCOLS",
"HTTP_102_PROCESSING",
"HTTP_103_EARLY_HINTS",
"HTTP_200_OK",
"HTTP_201_CREATED",
"HTTP_202_ACCEPTED",
"HTTP_203_NON_AUTHORITATIVE_INFORMATION",
"HTTP_204_NO_CONTENT",
"HTTP_205_RESET_CONTENT",
"HTTP_206_PARTIAL_CONTENT",
"HTTP_207_MULTI_STATUS",
"HTTP_208_ALREADY_REPORTED",
"HTTP_226_IM_USED",
"HTTP_300_MULTIPLE_CHOICES",
"HTTP_301_MOVED_PERMANENTLY",
"HTTP_302_FOUND",
"HTTP_303_SEE_OTHER",
"HTTP_304_NOT_MODIFIED",
"HTTP_305_USE_PROXY",
"HTTP_306_RESERVED",
"HTTP_307_TEMPORARY_REDIRECT",
"HTTP_308_PERMANENT_REDIRECT",
"HTTP_400_BAD_REQUEST",
"HTTP_401_UNAUTHORIZED",
"HTTP_402_PAYMENT_REQUIRED",
"HTTP_403_FORBIDDEN",
"HTTP_404_NOT_FOUND",
"HTTP_405_METHOD_NOT_ALLOWED",
"HTTP_406_NOT_ACCEPTABLE",
"HTTP_407_PROXY_AUTHENTICATION_REQUIRED",
"HTTP_408_REQUEST_TIMEOUT",
"HTTP_409_CONFLICT",
"HTTP_410_GONE",
"HTTP_411_LENGTH_REQUIRED",
"HTTP_412_PRECONDITION_FAILED",
"HTTP_413_REQUEST_ENTITY_TOO_LARGE",
"HTTP_414_REQUEST_URI_TOO_LONG",
"HTTP_415_UNSUPPORTED_MEDIA_TYPE",
"HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE",
"HTTP_417_EXPECTATION_FAILED",
"HTTP_418_IM_A_TEAPOT",
"HTTP_421_MISDIRECTED_REQUEST",
"HTTP_422_UNPROCESSABLE_ENTITY",
"HTTP_423_LOCKED",
"HTTP_424_FAILED_DEPENDENCY",
"HTTP_425_TOO_EARLY",
"HTTP_426_UPGRADE_REQUIRED",
"HTTP_428_PRECONDITION_REQUIRED",
"HTTP_429_TOO_MANY_REQUESTS",
"HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE",
"HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS",
"HTTP_500_INTERNAL_SERVER_ERROR",
"HTTP_501_NOT_IMPLEMENTED",
"HTTP_502_BAD_GATEWAY",
"HTTP_503_SERVICE_UNAVAILABLE",
"HTTP_504_GATEWAY_TIMEOUT",
"HTTP_505_HTTP_VERSION_NOT_SUPPORTED",
"HTTP_506_VARIANT_ALSO_NEGOTIATES",
"HTTP_507_INSUFFICIENT_STORAGE",
"HTTP_508_LOOP_DETECTED",
"HTTP_510_NOT_EXTENDED",
"HTTP_511_NETWORK_AUTHENTICATION_REQUIRED",
"WS_1000_NORMAL_CLOSURE",
"WS_1001_GOING_AWAY",
"WS_1002_PROTOCOL_ERROR",
"WS_1003_UNSUPPORTED_DATA",
"WS_1005_NO_STATUS_RCVD",
"WS_1006_ABNORMAL_CLOSURE",
"WS_1007_INVALID_FRAME_PAYLOAD_DATA",
"WS_1008_POLICY_VIOLATION",
"WS_1009_MESSAGE_TOO_BIG",
"WS_1010_MANDATORY_EXT",
"WS_1011_INTERNAL_ERROR",
"WS_1012_SERVICE_RESTART",
"WS_1013_TRY_AGAIN_LATER",
"WS_1014_BAD_GATEWAY",
"WS_1015_TLS_HANDSHAKE",
)
HTTP_100_CONTINUE = 100
HTTP_101_SWITCHING_PROTOCOLS = 101
HTTP_102_PROCESSING = 102
@@ -175,26 +93,3 @@ WS_1012_SERVICE_RESTART = 1012
WS_1013_TRY_AGAIN_LATER = 1013
WS_1014_BAD_GATEWAY = 1014
WS_1015_TLS_HANDSHAKE = 1015
__deprecated__ = {"WS_1004_NO_STATUS_RCVD": 1004, "WS_1005_ABNORMAL_CLOSURE": 1005}
def __getattr__(name: str) -> int:
deprecation_changes = {
"WS_1004_NO_STATUS_RCVD": "WS_1005_NO_STATUS_RCVD",
"WS_1005_ABNORMAL_CLOSURE": "WS_1006_ABNORMAL_CLOSURE",
}
deprecated = __deprecated__.get(name)
if deprecated:
warnings.warn(
f"'{name}' is deprecated. Use '{deprecation_changes[name]}' instead.",
category=DeprecationWarning,
stacklevel=3,
)
return deprecated
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
def __dir__() -> list[str]:
return sorted(list(__all__) + list(__deprecated__.keys())) # pragma: no cover

View File

@@ -19,9 +19,9 @@ try:
# adding a type ignore for mypy to let us access an attribute that may not exist
if hasattr(jinja2, "pass_context"):
pass_context = jinja2.pass_context
else: # pragma: nocover
else: # pragma: no cover
pass_context = jinja2.contextfunction # type: ignore[attr-defined]
except ModuleNotFoundError: # pragma: nocover
except ModuleNotFoundError: # pragma: no cover
jinja2 = None # type: ignore[assignment]
@@ -43,7 +43,7 @@ class _TemplateResponse(HTMLResponse):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = self.context.get("request", {})
extensions = request.get("extensions", {})
if "http.response.debug" in extensions:
if "http.response.debug" in extensions: # pragma: no branch
await send(
{
"type": "http.response.debug",
@@ -66,58 +66,46 @@ class Jinja2Templates:
@typing.overload
def __init__(
self,
directory: str
| PathLike[typing.AnyStr]
| typing.Sequence[str | PathLike[typing.AnyStr]],
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
*,
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
| None = None,
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
**env_options: typing.Any,
) -> None:
...
) -> None: ...
@typing.overload
def __init__(
self,
*,
env: jinja2.Environment,
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
| None = None,
) -> None:
...
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
) -> None: ...
def __init__(
self,
directory: str
| PathLike[typing.AnyStr]
| typing.Sequence[str | PathLike[typing.AnyStr]]
| None = None,
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None,
*,
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
| None = None,
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
env: jinja2.Environment | None = None,
**env_options: typing.Any,
) -> None:
if env_options:
warnings.warn(
"Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.", # noqa: E501
"Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
DeprecationWarning,
)
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
assert directory or env, "either 'directory' or 'env' arguments must be passed"
assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
self.context_processors = context_processors or []
if directory is not None:
self.env = self._create_env(directory, **env_options)
elif env is not None:
elif env is not None: # pragma: no branch
self.env = env
self._setup_env_defaults(self.env)
def _create_env(
self,
directory: str
| PathLike[typing.AnyStr]
| typing.Sequence[str | PathLike[typing.AnyStr]],
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
**env_options: typing.Any,
) -> jinja2.Environment:
loader = jinja2.FileSystemLoader(directory)
@@ -129,7 +117,7 @@ class Jinja2Templates:
def _setup_env_defaults(self, env: jinja2.Environment) -> None:
@pass_context
def url_for(
context: typing.Dict[str, typing.Any],
context: dict[str, typing.Any],
name: str,
/,
**path_params: typing.Any,
@@ -152,8 +140,7 @@ class Jinja2Templates:
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> _TemplateResponse:
...
) -> _TemplateResponse: ...
@typing.overload
def TemplateResponse(
@@ -168,25 +155,19 @@ class Jinja2Templates:
# Deprecated usage
...
def TemplateResponse(
self, *args: typing.Any, **kwargs: typing.Any
) -> _TemplateResponse:
def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse:
if args:
if isinstance(
args[0], str
): # the first argument is template name (old style)
if isinstance(args[0], str): # the first argument is template name (old style)
warnings.warn(
"The `name` is not the first parameter anymore. "
"The first parameter should be the `Request` instance.\n"
'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', # noqa: E501
'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
DeprecationWarning,
)
name = args[0]
context = args[1] if len(args) > 1 else kwargs.get("context", {})
status_code = (
args[2] if len(args) > 2 else kwargs.get("status_code", 200)
)
status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
headers = args[2] if len(args) > 2 else kwargs.get("headers")
media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
background = args[4] if len(args) > 4 else kwargs.get("background")
@@ -198,9 +179,7 @@ class Jinja2Templates:
request = args[0]
name = args[1] if len(args) > 1 else kwargs["name"]
context = args[2] if len(args) > 2 else kwargs.get("context", {})
status_code = (
args[3] if len(args) > 3 else kwargs.get("status_code", 200)
)
status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
headers = args[4] if len(args) > 4 else kwargs.get("headers")
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
background = args[6] if len(args) > 6 else kwargs.get("background")
@@ -208,7 +187,7 @@ class Jinja2Templates:
if "request" not in kwargs:
warnings.warn(
"The `TemplateResponse` now requires the `request` argument.\n"
'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', # noqa: E501
'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
DeprecationWarning,
)
if "request" not in kwargs.get("context", {}):

View File

@@ -5,19 +5,15 @@ import inspect
import io
import json
import math
import queue
import sys
import typing
import warnings
from concurrent.futures import Future
from functools import cached_property
from types import GeneratorType
from urllib.parse import unquote, urljoin
import anyio
import anyio.abc
import anyio.from_thread
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.streams.stapled import StapledObjectStream
from starlette._utils import is_async_callable
@@ -37,16 +33,14 @@ except ModuleNotFoundError: # pragma: no cover
"You can install this with:\n"
" $ pip install httpx\n"
)
_PortalFactoryType = typing.Callable[
[], typing.ContextManager[anyio.abc.BlockingPortal]
]
_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
@@ -78,6 +72,16 @@ class _Upgrade(Exception):
self.session = session
class WebSocketDenialResponse( # type: ignore[misc]
httpx.Response,
WebSocketDisconnect,
):
"""
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
`WebSocket` is closed before being accepted with a `send_denial_response()`.
"""
class WebSocketTestSession:
def __init__(
self,
@@ -89,81 +93,60 @@ class WebSocketTestSession:
self.scope = scope
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: queue.Queue[Message] = queue.Queue()
self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
self.extra_headers = None
def __enter__(self) -> WebSocketTestSession:
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(self.portal_factory())
try:
_: Future[None] = self.portal.start_task_soon(self._run)
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(self.portal_factory())
fut, cs = portal.start_task(self._run)
stack.callback(fut.result)
stack.callback(portal.call, cs.cancel)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
except Exception:
self.exit_stack.close()
raise
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
return self
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
stack.callback(self.close, 1000)
self.exit_stack = stack.pop_all()
return self
@cached_property
def should_close(self) -> anyio.Event:
return anyio.Event()
def __exit__(self, *args: typing.Any) -> bool | None:
return self.exit_stack.__exit__(*args)
async def _notify_close(self) -> None:
self.should_close.set()
def __exit__(self, *args: typing.Any) -> None:
try:
self.close(1000)
finally:
self.portal.start_task_soon(self._notify_close)
self.exit_stack.close()
while not self._send_queue.empty():
message = self._send_queue.get()
if isinstance(message, BaseException):
raise message
async def _run(self) -> None:
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
"""
The sub-thread in which the websocket session runs.
"""
send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
send_tx, send_rx = send
receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
receive_tx, receive_rx = receive
with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
self._receive_tx = receive_tx
self._send_rx = send_rx
task_status.started(cs)
await self.app(self.scope, receive_rx.receive, send_tx.send)
async def run_app(tg: anyio.abc.TaskGroup) -> None:
try:
await self.app(self.scope, self._asgi_receive, self._asgi_send)
except anyio.get_cancelled_exc_class():
...
except BaseException as exc:
self._send_queue.put(exc)
raise
finally:
tg.cancel_scope.cancel()
async with anyio.create_task_group() as tg:
tg.start_soon(run_app, tg)
await self.should_close.wait()
tg.cancel_scope.cancel()
async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
await anyio.sleep(0)
return self._receive_queue.get()
async def _asgi_send(self, message: Message) -> None:
self._send_queue.put(message)
# wait for cs.cancel to be called before closing streams
await anyio.sleep_forever()
def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(
message.get("code", 1000), message.get("reason", "")
)
raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
elif message["type"] == "websocket.http.response.start":
status_code: int = message["status"]
headers: list[tuple[bytes, bytes]] = message["headers"]
body: list[bytes] = []
while True:
message = self.receive()
assert message["type"] == "websocket.http.response.body"
body.append(message["body"])
if not message.get("more_body", False):
break
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
def send(self, message: Message) -> None:
self._receive_queue.put(message)
self.portal.call(self._receive_tx.send, message)
def send_text(self, data: str) -> None:
self.send({"type": "websocket.receive", "text": data})
@@ -171,9 +154,7 @@ class WebSocketTestSession:
def send_bytes(self, data: bytes) -> None:
self.send({"type": "websocket.receive", "bytes": data})
def send_json(
self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text"
) -> None:
def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
if mode == "text":
self.send({"type": "websocket.receive", "text": text})
@@ -184,10 +165,7 @@ class WebSocketTestSession:
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
def receive(self) -> Message:
message = self._send_queue.get()
if isinstance(message, BaseException):
raise message
return message
return self.portal.call(self._send_rx.receive)
def receive_text(self) -> str:
message = self.receive()
@@ -199,9 +177,7 @@ class WebSocketTestSession:
self._raise_on_close(message)
return typing.cast(bytes, message["bytes"])
def receive_json(
self, mode: typing.Literal["text", "binary"] = "text"
) -> typing.Any:
def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
message = self.receive()
self._raise_on_close(message)
if mode == "text":
@@ -219,6 +195,7 @@ class _TestClientTransport(httpx.BaseTransport):
raise_server_exceptions: bool = True,
root_path: str = "",
*,
client: tuple[str, int],
app_state: dict[str, typing.Any],
) -> None:
self.app = app
@@ -226,6 +203,7 @@ class _TestClientTransport(httpx.BaseTransport):
self.root_path = root_path
self.portal_factory = portal_factory
self.app_state = app_state
self.client = client
def handle_request(self, request: httpx.Request) -> httpx.Response:
scheme = request.url.scheme
@@ -252,10 +230,7 @@ class _TestClientTransport(httpx.BaseTransport):
headers = [(b"host", (f"{host}:{port}").encode())]
# Include other request headers.
headers += [
(key.lower().encode(), value.encode())
for key, value in request.headers.multi_items()
]
headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
scope: dict[str, typing.Any]
@@ -268,15 +243,16 @@ class _TestClientTransport(httpx.BaseTransport):
scope = {
"type": "websocket",
"path": unquote(path),
"raw_path": raw_path,
"raw_path": raw_path.split(b"?", 1)[0],
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": None,
"client": self.client,
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
@@ -286,12 +262,12 @@ class _TestClientTransport(httpx.BaseTransport):
"http_version": "1.1",
"method": request.method,
"path": unquote(path),
"raw_path": raw_path,
"raw_path": raw_path.split(b"?", 1)[0],
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
"client": None,
"client": self.client,
"server": [host, port],
"extensions": {"http.response.debug": {}},
"state": self.app_state.copy(),
@@ -336,22 +312,13 @@ class _TestClientTransport(httpx.BaseTransport):
nonlocal raw_kwargs, response_started, template, context
if message["type"] == "http.response.start":
assert (
not response_started
), 'Received multiple "http.response.start" messages.'
assert not response_started, 'Received multiple "http.response.start" messages.'
raw_kwargs["status_code"] = message["status"]
raw_kwargs["headers"] = [
(key.decode(), value.decode())
for key, value in message.get("headers", [])
]
raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
response_started = True
elif message["type"] == "http.response.body":
assert (
response_started
), 'Received "http.response.body" without "http.response.start".'
assert (
not response_complete.is_set()
), 'Received "http.response.body" after response completed.'
assert response_started, 'Received "http.response.body" without "http.response.start".'
assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
body = message.get("body", b"")
more_body = message.get("more_body", False)
if request.method != "HEAD":
@@ -405,10 +372,9 @@ class TestClient(httpx.Client):
cookies: httpx._types.CookieTypes | None = None,
headers: dict[str, str] | None = None,
follow_redirects: bool = True,
client: tuple[str, int] = ("testclient", 50000),
) -> None:
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
)
self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
if _is_asgi3(app):
asgi_app = app
else:
@@ -422,12 +388,12 @@ class TestClient(httpx.Client):
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
app_state=self.app_state,
client=client,
)
if headers is None:
headers = {}
headers.setdefault("user-agent", "testclient")
super().__init__(
app=self.app,
base_url=base_url,
headers=headers,
transport=transport,
@@ -440,32 +406,9 @@ class TestClient(httpx.Client):
if self.portal is not None:
yield self.portal
else:
with anyio.from_thread.start_blocking_portal(
**self.async_backend
) as portal:
with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
yield portal
def _choose_redirect_arg(
self, follow_redirects: bool | None, allow_redirects: bool | None
) -> bool | httpx._client.UseClientDefault:
redirect: bool | httpx._client.UseClientDefault = (
httpx._client.USE_CLIENT_DEFAULT
)
if allow_redirects is not None:
message = (
"The `allow_redirects` argument is deprecated. "
"Use `follow_redirects` instead."
)
warnings.warn(message, DeprecationWarning)
redirect = allow_redirects
if follow_redirects is not None:
redirect = follow_redirects
elif allow_redirects is not None and follow_redirects is not None:
raise RuntimeError( # pragma: no cover
"Cannot use both `allow_redirects` and `follow_redirects`."
)
return redirect
def request( # type: ignore[override]
self,
method: str,
@@ -478,16 +421,12 @@ class TestClient(httpx.Client):
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
url = self._merge_url(url)
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().request(
method,
url,
@@ -499,7 +438,7 @@ class TestClient(httpx.Client):
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -511,22 +450,18 @@ class TestClient(httpx.Client):
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().get(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -538,22 +473,18 @@ class TestClient(httpx.Client):
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().options(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -565,22 +496,18 @@ class TestClient(httpx.Client):
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().head(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -596,15 +523,11 @@ class TestClient(httpx.Client):
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().post(
url,
content=content,
@@ -615,7 +538,7 @@ class TestClient(httpx.Client):
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -631,15 +554,11 @@ class TestClient(httpx.Client):
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().put(
url,
content=content,
@@ -650,7 +569,7 @@ class TestClient(httpx.Client):
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -666,15 +585,11 @@ class TestClient(httpx.Client):
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().patch(
url,
content=content,
@@ -685,7 +600,7 @@ class TestClient(httpx.Client):
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -697,22 +612,18 @@ class TestClient(httpx.Client):
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
timeout: httpx._types.TimeoutTypes
| httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().delete(
url,
params=params,
headers=headers,
cookies=cookies,
auth=auth,
follow_redirects=redirect,
follow_redirects=follow_redirects,
timeout=timeout,
extensions=extensions,
)
@@ -742,22 +653,22 @@ class TestClient(httpx.Client):
def __enter__(self) -> TestClient:
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(
anyio.from_thread.start_blocking_portal(**self.async_backend)
)
self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
@stack.callback
def reset_portal() -> None:
self.portal = None
send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send1, receive1 = anyio.create_memory_object_stream(math.inf)
send2, receive2 = anyio.create_memory_object_stream(math.inf)
self.stream_send = StapledObjectStream(send1, receive1)
self.stream_receive = StapledObjectStream(send2, receive2)
send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = (
anyio.create_memory_object_stream(math.inf)
)
receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = (
anyio.create_memory_object_stream(math.inf)
)
for channel in (*send, *receive):
stack.callback(channel.close)
self.stream_send = StapledObjectStream(*send)
self.stream_receive = StapledObjectStream(*receive)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)
@@ -803,12 +714,11 @@ class TestClient(httpx.Client):
self.task.result()
return message
async with self.stream_send:
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()

View File

@@ -16,15 +16,9 @@ Send = typing.Callable[[Message], typing.Awaitable[None]]
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
StatefulLifespan = typing.Callable[
[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
]
StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]]
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
HTTPExceptionHandler = typing.Callable[
["Request", Exception], "Response | typing.Awaitable[Response]"
]
WebSocketExceptionHandler = typing.Callable[
["WebSocket", Exception], typing.Awaitable[None]
]
HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"]
WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]]
ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]

View File

@@ -5,6 +5,7 @@ import json
import typing
from starlette.requests import HTTPConnection
from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send
@@ -12,10 +13,11 @@ class WebSocketState(enum.Enum):
CONNECTING = 0
CONNECTED = 1
DISCONNECTED = 2
RESPONSE = 3
class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
@@ -37,10 +39,7 @@ class WebSocket(HTTPConnection):
message = await self._receive()
message_type = message["type"]
if message_type != "websocket.connect":
raise RuntimeError(
'Expected ASGI message "websocket.connect", '
f"but got {message_type!r}"
)
raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
@@ -48,16 +47,13 @@ class WebSocket(HTTPConnection):
message_type = message["type"]
if message_type not in {"websocket.receive", "websocket.disconnect"}:
raise RuntimeError(
'Expected ASGI message "websocket.receive" or '
f'"websocket.disconnect", but got {message_type!r}'
f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
return message
else:
raise RuntimeError(
'Cannot call "receive" once a disconnect message has been received.'
)
raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
async def send(self, message: Message) -> None:
"""
@@ -65,13 +61,15 @@ class WebSocket(HTTPConnection):
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
if message_type not in {"websocket.accept", "websocket.close"}:
if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
raise RuntimeError(
'Expected ASGI message "websocket.accept" or '
f'"websocket.close", but got {message_type!r}'
'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
elif message_type == "websocket.http.response.start":
self.application_state = WebSocketState.RESPONSE
else:
self.application_state = WebSocketState.CONNECTED
await self._send(message)
@@ -79,16 +77,22 @@ class WebSocket(HTTPConnection):
message_type = message["type"]
if message_type not in {"websocket.send", "websocket.close"}:
raise RuntimeError(
'Expected ASGI message "websocket.send" or "websocket.close", '
f"but got {message_type!r}"
f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
try:
await self._send(message)
except IOError:
except OSError:
self.application_state = WebSocketState.DISCONNECTED
raise WebSocketDisconnect(code=1006)
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
if message_type != "websocket.http.response.body":
raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')
@@ -99,12 +103,10 @@ class WebSocket(HTTPConnection):
) -> None:
headers = headers or []
if self.client_state == WebSocketState.CONNECTING:
if self.client_state == WebSocketState.CONNECTING: # pragma: no branch
# If we haven't yet seen the 'connect' message, then wait for it first.
await self.receive()
await self.send(
{"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
)
await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
def _raise_on_disconnect(self, message: Message) -> None:
if message["type"] == "websocket.disconnect":
@@ -112,18 +114,14 @@ class WebSocket(HTTPConnection):
async def receive_text(self) -> str:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return typing.cast(str, message["text"])
async def receive_bytes(self) -> bytes:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return typing.cast(bytes, message["bytes"])
@@ -132,9 +130,7 @@ class WebSocket(HTTPConnection):
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
@@ -181,9 +177,13 @@ class WebSocket(HTTPConnection):
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
async def close(self, code: int = 1000, reason: str | None = None) -> None:
await self.send(
{"type": "websocket.close", "code": code, "reason": reason or ""}
)
await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
async def send_denial_response(self, response: Response) -> None:
if "websocket.http.response" in self.scope.get("extensions", {}):
await response(self.scope, self.receive, self.send)
else:
raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
class WebSocketClose:
@@ -192,6 +192,4 @@ class WebSocketClose:
self.reason = reason or ""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send(
{"type": "websocket.close", "code": self.code, "reason": self.reason}
)
await send({"type": "websocket.close", "code": self.code, "reason": self.reason})