1
0
mirror of https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git synced 2025-08-14 00:25:46 +02:00

Проверка 09.02.2025

This commit is contained in:
MoonTestUse1
2025-02-09 01:11:49 +06:00
parent ce52f8a23a
commit 0aa3ef8fc2
5827 changed files with 14316 additions and 1906434 deletions

View File

@@ -1,5 +1,5 @@
from uvicorn.config import Config
from uvicorn.main import Server, main, run
__version__ = "0.27.1"
__version__ = "0.34.0"
__all__ = ["main", "run", "Config", "Server"]

View File

@@ -2,6 +2,7 @@
Some light wrappers around Python's multiprocessing, to deal with cleanly
starting child processes.
"""
from __future__ import annotations
import multiprocessing
@@ -9,7 +10,7 @@ import os
import sys
from multiprocessing.context import SpawnProcess
from socket import socket
from typing import Callable, Optional
from typing import Callable
from uvicorn.config import Config
@@ -34,10 +35,10 @@ def get_subprocess(
"""
# We pass across the stdin fileno, and reopen it in the child process.
# This is required for some debugging environments.
stdin_fileno: Optional[int]
try:
stdin_fileno = sys.stdin.fileno()
except OSError:
# The `sys.stdin` can be `None`, see https://docs.python.org/3/library/sys.html#sys.__stdin__.
except (AttributeError, OSError):
stdin_fileno = None
kwargs = {
@@ -69,10 +70,15 @@ def subprocess_started(
"""
# Re-open stdin.
if stdin_fileno is not None:
sys.stdin = os.fdopen(stdin_fileno)
sys.stdin = os.fdopen(stdin_fileno) # pragma: full coverage
# Logging needs to be setup again for each child.
config.configure_logging()
# Now we can call into `Server.run(sockets=sockets)`
target(sockets=sockets)
try:
# Now we can call into `Server.run(sockets=sockets)`
target(sockets=sockets)
except KeyboardInterrupt: # pragma: no cover
# supress the exception to avoid a traceback from subprocess.Popen
# the parent already expects us to end, so no vital information is lost
pass

View File

@@ -27,26 +27,13 @@ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from __future__ import annotations
import sys
import types
from typing import (
Any,
Awaitable,
Callable,
Iterable,
MutableMapping,
Optional,
Tuple,
Type,
Union,
)
if sys.version_info >= (3, 8): # pragma: py-lt-38
from typing import Literal, Protocol, TypedDict
else: # pragma: py-gte-38
from typing_extensions import Literal, Protocol, TypedDict
from collections.abc import Awaitable, Iterable, MutableMapping
from typing import Any, Callable, Literal, Optional, Protocol, TypedDict, Union
if sys.version_info >= (3, 11): # pragma: py-lt-311
from typing import NotRequired
@@ -55,8 +42,8 @@ else: # pragma: py-gte-311
# WSGI
Environ = MutableMapping[str, Any]
ExcInfo = Tuple[Type[BaseException], BaseException, Optional[types.TracebackType]]
StartResponse = Callable[[str, Iterable[Tuple[str, str]], Optional[ExcInfo]], None]
ExcInfo = tuple[type[BaseException], BaseException, Optional[types.TracebackType]]
StartResponse = Callable[[str, Iterable[tuple[str, str]], Optional[ExcInfo]], None]
WSGIApp = Callable[[Environ, StartResponse], Union[Iterable[bytes], BaseException]]
@@ -205,6 +192,7 @@ class WebSocketResponseBodyEvent(TypedDict):
class WebSocketDisconnectEvent(TypedDict):
type: Literal["websocket.disconnect"]
code: int
reason: NotRequired[str | None]
class WebSocketCloseEvent(TypedDict):
@@ -239,9 +227,7 @@ class LifespanShutdownFailedEvent(TypedDict):
message: str
WebSocketEvent = Union[
WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent
]
WebSocketEvent = Union[WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent]
ASGIReceiveEvent = Union[
@@ -278,16 +264,12 @@ ASGISendCallable = Callable[[ASGISendEvent], Awaitable[None]]
class ASGI2Protocol(Protocol):
def __init__(self, scope: Scope) -> None:
... # pragma: no cover
def __init__(self, scope: Scope) -> None: ... # pragma: no cover
async def __call__(
self, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
... # pragma: no cover
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... # pragma: no cover
ASGI2Application = Type[ASGI2Protocol]
ASGI2Application = type[ASGI2Protocol]
ASGI3Application = Callable[
[
Scope,

View File

@@ -9,8 +9,10 @@ import os
import socket
import ssl
import sys
from collections.abc import Awaitable
from configparser import RawConfigParser
from pathlib import Path
from typing import Any, Awaitable, Callable, Literal
from typing import IO, Any, Callable, Literal
import click
@@ -123,13 +125,11 @@ def is_dir(path: Path) -> bool:
if not path.is_absolute():
path = path.resolve()
return path.is_dir()
except OSError:
except OSError: # pragma: full coverage
return False
def resolve_reload_patterns(
patterns_list: list[str], directories_list: list[str]
) -> tuple[list[str], list[Path]]:
def resolve_reload_patterns(patterns_list: list[str], directories_list: list[str]) -> tuple[list[str], list[Path]]:
directories: list[Path] = list(set(map(Path, directories_list.copy())))
patterns: list[str] = patterns_list.copy()
@@ -138,7 +138,7 @@ def resolve_reload_patterns(
# Special case for the .* pattern, otherwise this would only match
# hidden directories which is probably undesired
if pattern == ".*":
continue
continue # pragma: py-darwin
patterns.append(pattern)
if is_dir(Path(pattern)):
directories.append(Path(pattern))
@@ -150,15 +150,13 @@ def resolve_reload_patterns(
directories = list(set(directories))
directories = list(map(Path, directories))
directories = list(map(lambda x: x.resolve(), directories))
directories = list(
{reload_path for reload_path in directories if is_dir(reload_path)}
)
directories = list({reload_path for reload_path in directories if is_dir(reload_path)})
children = []
for j in range(len(directories)):
for k in range(j + 1, len(directories)):
for k in range(j + 1, len(directories)): # pragma: full coverage
if directories[j] in directories[k].parents:
children.append(directories[k]) # pragma: py-darwin
children.append(directories[k])
elif directories[k] in directories[j].parents:
children.append(directories[j])
@@ -193,7 +191,7 @@ class Config:
ws_per_message_deflate: bool = True,
lifespan: LifespanType = "auto",
env_file: str | os.PathLike[str] | None = None,
log_config: dict[str, Any] | str | None = LOGGING_CONFIG,
log_config: dict[str, Any] | str | RawConfigParser | IO[Any] | None = LOGGING_CONFIG,
log_level: str | int | None = None,
access_log: bool = True,
use_colors: bool | None = None,
@@ -216,7 +214,7 @@ class Config:
timeout_notify: int = 30,
timeout_graceful_shutdown: int | None = None,
callback_notify: Callable[..., Awaitable[None]] | None = None,
ssl_keyfile: str | None = None,
ssl_keyfile: str | os.PathLike[str] | None = None,
ssl_certfile: str | os.PathLike[str] | None = None,
ssl_keyfile_password: str | None = None,
ssl_version: int = SSL_PROTOCOL_VERSION,
@@ -280,12 +278,9 @@ class Config:
self.reload_includes: list[str] = []
self.reload_excludes: list[str] = []
if (
reload_dirs or reload_includes or reload_excludes
) and not self.should_reload:
if (reload_dirs or reload_includes or reload_excludes) and not self.should_reload:
logger.warning(
"Current configuration will not reload as not all conditions are met, "
"please refer to documentation."
"Current configuration will not reload as not all conditions are met, " "please refer to documentation."
)
if self.should_reload:
@@ -293,30 +288,23 @@ class Config:
reload_includes = _normalize_dirs(reload_includes)
reload_excludes = _normalize_dirs(reload_excludes)
self.reload_includes, self.reload_dirs = resolve_reload_patterns(
reload_includes, reload_dirs
)
self.reload_includes, self.reload_dirs = resolve_reload_patterns(reload_includes, reload_dirs)
self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(
reload_excludes, []
)
self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(reload_excludes, [])
reload_dirs_tmp = self.reload_dirs.copy()
for directory in self.reload_dirs_excludes:
for reload_directory in reload_dirs_tmp:
if (
directory == reload_directory
or directory in reload_directory.parents
):
if directory == reload_directory or directory in reload_directory.parents:
try:
self.reload_dirs.remove(reload_directory)
except ValueError:
except ValueError: # pragma: full coverage
pass
for pattern in self.reload_excludes:
if pattern in self.reload_includes:
self.reload_includes.remove(pattern)
self.reload_includes.remove(pattern) # pragma: full coverage
if not self.reload_dirs:
if reload_dirs:
@@ -343,11 +331,9 @@ class Config:
self.forwarded_allow_ips: list[str] | str
if forwarded_allow_ips is None:
self.forwarded_allow_ips = os.environ.get(
"FORWARDED_ALLOW_IPS", "127.0.0.1"
)
self.forwarded_allow_ips = os.environ.get("FORWARDED_ALLOW_IPS", "127.0.0.1")
else:
self.forwarded_allow_ips = forwarded_allow_ips
self.forwarded_allow_ips = forwarded_allow_ips # pragma: full coverage
if self.reload and self.workers > 1:
logger.warning('"workers" flag is ignored when reloading is enabled.')
@@ -375,18 +361,14 @@ class Config:
if self.log_config is not None:
if isinstance(self.log_config, dict):
if self.use_colors in (True, False):
self.log_config["formatters"]["default"][
"use_colors"
] = self.use_colors
self.log_config["formatters"]["access"][
"use_colors"
] = self.use_colors
self.log_config["formatters"]["default"]["use_colors"] = self.use_colors
self.log_config["formatters"]["access"]["use_colors"] = self.use_colors
logging.config.dictConfig(self.log_config)
elif self.log_config.endswith(".json"):
elif isinstance(self.log_config, str) and self.log_config.endswith(".json"):
with open(self.log_config) as file:
loaded_config = json.load(file)
logging.config.dictConfig(loaded_config)
elif self.log_config.endswith((".yaml", ".yml")):
elif isinstance(self.log_config, str) and self.log_config.endswith((".yaml", ".yml")):
# Install the PyYAML package or the uvicorn[standard] optional
# dependencies to enable this functionality.
import yaml
@@ -397,9 +379,7 @@ class Config:
else:
# See the note about fileConfig() here:
# https://docs.python.org/3/library/logging.config.html#configuration-file-format
logging.config.fileConfig(
self.log_config, disable_existing_loggers=False
)
logging.config.fileConfig(self.log_config, disable_existing_loggers=False)
if self.log_level is not None:
if isinstance(self.log_level, str):
@@ -430,10 +410,7 @@ class Config:
else:
self.ssl = None
encoded_headers = [
(key.lower().encode("latin1"), value.encode("latin1"))
for key, value in self.headers
]
encoded_headers = [(key.lower().encode("latin1"), value.encode("latin1")) for key, value in self.headers]
self.encoded_headers = (
[(b"server", b"uvicorn")] + encoded_headers
if b"server" not in dict(encoded_headers) and self.server_header
@@ -469,8 +446,7 @@ class Config:
else:
if not self.factory:
logger.warning(
"ASGI app factory detected. Using it, "
"but please consider setting the --factory flag explicitly."
"ASGI app factory detected. Using it, " "but please consider setting the --factory flag explicitly."
)
if self.interface == "auto":
@@ -492,9 +468,7 @@ class Config:
if logger.getEffectiveLevel() <= TRACE_LOG_LEVEL:
self.loaded_app = MessageLoggerMiddleware(self.loaded_app)
if self.proxy_headers:
self.loaded_app = ProxyHeadersMiddleware(
self.loaded_app, trusted_hosts=self.forwarded_allow_ips
)
self.loaded_app = ProxyHeadersMiddleware(self.loaded_app, trusted_hosts=self.forwarded_allow_ips)
self.loaded = True
@@ -512,33 +486,25 @@ class Config:
sock.bind(path)
uds_perms = 0o666
os.chmod(self.uds, uds_perms)
except OSError as exc:
except OSError as exc: # pragma: full coverage
logger.error(exc)
sys.exit(1)
message = "Uvicorn running on unix socket %s (Press CTRL+C to quit)"
sock_name_format = "%s"
color_message = (
"Uvicorn running on "
+ click.style(sock_name_format, bold=True)
+ " (Press CTRL+C to quit)"
)
color_message = "Uvicorn running on " + click.style(sock_name_format, bold=True) + " (Press CTRL+C to quit)"
logger_args = [self.uds]
elif self.fd: # pragma: py-win32
sock = socket.fromfd(self.fd, socket.AF_UNIX, socket.SOCK_STREAM)
message = "Uvicorn running on socket %s (Press CTRL+C to quit)"
fd_name_format = "%s"
color_message = (
"Uvicorn running on "
+ click.style(fd_name_format, bold=True)
+ " (Press CTRL+C to quit)"
)
color_message = "Uvicorn running on " + click.style(fd_name_format, bold=True) + " (Press CTRL+C to quit)"
logger_args = [sock.getsockname()]
else:
family = socket.AF_INET
addr_format = "%s://%s:%d"
if self.host and ":" in self.host: # pragma: py-win32
if self.host and ":" in self.host: # pragma: full coverage
# It's an IPv6 address.
family = socket.AF_INET6
addr_format = "%s://[%s]:%d"
@@ -547,16 +513,12 @@ class Config:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sock.bind((self.host, self.port))
except OSError as exc:
except OSError as exc: # pragma: full coverage
logger.error(exc)
sys.exit(1)
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
color_message = (
"Uvicorn running on "
+ click.style(addr_format, bold=True)
+ " (Press CTRL+C to quit)"
)
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
protocol_name = "https" if self.is_ssl else "http"
logger_args = [protocol_name, self.host, sock.getsockname()[1]]
logger.info(message, *logger_args, extra={"color_message": color_message})

View File

@@ -12,9 +12,7 @@ def import_from_string(import_str: Any) -> Any:
module_str, _, attrs_str = import_str.partition(":")
if not module_str or not attrs_str:
message = (
'Import string "{import_str}" must be in format "<module>:<attribute>".'
)
message = 'Import string "{import_str}" must be in format "<module>:<attribute>".'
raise ImportFromStringError(message.format(import_str=import_str))
try:
@@ -31,8 +29,6 @@ def import_from_string(import_str: Any) -> Any:
instance = getattr(instance, attr_str)
except AttributeError:
message = 'Attribute "{attrs_str}" not found in module "{module_str}".'
raise ImportFromStringError(
message.format(attrs_str=attrs_str, module_str=module_str)
)
raise ImportFromStringError(message.format(attrs_str=attrs_str, module_str=module_str))
return instance

View File

@@ -1,4 +1,6 @@
from typing import Any, Dict
from __future__ import annotations
from typing import Any
from uvicorn import Config
@@ -6,7 +8,7 @@ from uvicorn import Config
class LifespanOff:
def __init__(self, config: Config) -> None:
self.should_exit = False
self.state: Dict[str, Any] = {}
self.state: dict[str, Any] = {}
async def startup(self) -> None:
pass

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import asyncio
import logging
from asyncio import Queue
from typing import Any, Dict, Union
from typing import Any, Union
from uvicorn import Config
from uvicorn._types import (
@@ -35,12 +37,12 @@ class LifespanOn:
self.logger = logging.getLogger("uvicorn.error")
self.startup_event = asyncio.Event()
self.shutdown_event = asyncio.Event()
self.receive_queue: "Queue[LifespanReceiveMessage]" = asyncio.Queue()
self.receive_queue: Queue[LifespanReceiveMessage] = asyncio.Queue()
self.error_occured = False
self.startup_failed = False
self.shutdown_failed = False
self.should_exit = False
self.state: Dict[str, Any] = {}
self.state: dict[str, Any] = {}
async def startup(self) -> None:
self.logger.info("Waiting for application startup.")
@@ -67,9 +69,7 @@ class LifespanOn:
await self.receive_queue.put(shutdown_event)
await self.shutdown_event.wait()
if self.shutdown_failed or (
self.error_occured and self.config.lifespan == "on"
):
if self.shutdown_failed or (self.error_occured and self.config.lifespan == "on"):
self.logger.error("Application shutdown failed. Exiting.")
self.should_exit = True
else:
@@ -99,7 +99,7 @@ class LifespanOn:
self.startup_event.set()
self.shutdown_event.set()
async def send(self, message: "LifespanSendMessage") -> None:
async def send(self, message: LifespanSendMessage) -> None:
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
@@ -133,5 +133,5 @@ class LifespanOn:
if message.get("message"):
self.logger.error(message["message"])
async def receive(self) -> "LifespanReceiveMessage":
async def receive(self) -> LifespanReceiveMessage:
return await self.receive_queue.get()

View File

@@ -16,7 +16,7 @@ class ColourizedFormatter(logging.Formatter):
A custom log formatter class that:
* Outputs the LOG_LEVEL with an appropriate color.
* If a log call includes an `extras={"color_message": ...}` it will be used
* If a log call includes an `extra={"color_message": ...}` it will be used
for formatting the output, instead of the plain text message.
"""
@@ -26,9 +26,7 @@ class ColourizedFormatter(logging.Formatter):
logging.INFO: lambda level_name: click.style(str(level_name), fg="green"),
logging.WARNING: lambda level_name: click.style(str(level_name), fg="yellow"),
logging.ERROR: lambda level_name: click.style(str(level_name), fg="red"),
logging.CRITICAL: lambda level_name: click.style(
str(level_name), fg="bright_red"
),
logging.CRITICAL: lambda level_name: click.style(str(level_name), fg="bright_red"),
}
def __init__(
@@ -86,7 +84,7 @@ class AccessFormatter(ColourizedFormatter):
status_phrase = http.HTTPStatus(status_code).phrase
except ValueError:
status_phrase = ""
status_and_phrase = "%s %s" % (status_code, status_phrase)
status_and_phrase = f"{status_code} {status_phrase}"
if self.use_colors:
def default(code: int) -> str:
@@ -106,7 +104,7 @@ class AccessFormatter(ColourizedFormatter):
status_code,
) = recordcopy.args # type: ignore[misc]
status_code = self.get_status_code(int(status_code)) # type: ignore[arg-type]
request_line = "%s %s HTTP/%s" % (method, full_path, http_version)
request_line = f"{method} {full_path} HTTP/{http_version}"
if self.use_colors:
request_line = click.style(request_line, bold=True)
recordcopy.__dict__.update(

View File

@@ -7,4 +7,4 @@ logger = logging.getLogger("uvicorn.error")
def asyncio_setup(use_subprocess: bool = False) -> None:
if sys.platform == "win32" and use_subprocess:
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # pragma: full coverage

View File

@@ -6,7 +6,8 @@ import os
import platform
import ssl
import sys
from typing import Any, Callable
from configparser import RawConfigParser
from typing import IO, Any, Callable
import click
@@ -47,12 +48,11 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
if not value or ctx.resilient_parsing:
return
click.echo(
"Running uvicorn %s with %s %s on %s"
% (
uvicorn.__version__,
platform.python_implementation(),
platform.python_version(),
platform.system(),
"Running uvicorn {version} with {py_implementation} {py_version} on {system}".format( # noqa: UP032
version=uvicorn.__version__,
py_implementation=platform.python_implementation(),
py_version=platform.python_version(),
system=platform.system(),
)
)
ctx.exit()
@@ -75,16 +75,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
show_default=True,
)
@click.option("--uds", type=str, default=None, help="Bind to a UNIX domain socket.")
@click.option(
"--fd", type=int, default=None, help="Bind to socket from this file descriptor."
)
@click.option("--fd", type=int, default=None, help="Bind to socket from this file descriptor.")
@click.option("--reload", is_flag=True, default=False, help="Enable auto-reload.")
@click.option(
"--reload-dir",
"reload_dirs",
multiple=True,
help="Set reload directories explicitly, instead of using the current working"
" directory.",
help="Set reload directories explicitly, instead of using the current working" " directory.",
type=click.Path(exists=True),
)
@click.option(
@@ -109,8 +106,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
type=float,
default=0.25,
show_default=True,
help="Delay between previous and next check if application needs to be."
" Defaults to 0.25s.",
help="Delay between previous and next check if application needs to be." " Defaults to 0.25s.",
)
@click.option(
"--workers",
@@ -226,8 +222,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
"--proxy-headers/--no-proxy-headers",
is_flag=True,
default=True,
help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to "
"populate remote address info.",
help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to " "populate remote address info.",
)
@click.option(
"--server-header/--no-server-header",
@@ -245,8 +240,10 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
"--forwarded-allow-ips",
type=str,
default=None,
help="Comma separated list of IPs to trust with proxy headers. Defaults to"
" the $FORWARDED_ALLOW_IPS environment variable if available, or '127.0.0.1'.",
help="Comma separated list of IP Addresses, IP Networks, or literals "
"(e.g. UNIX Socket path) to trust with proxy headers. Defaults to the "
"$FORWARDED_ALLOW_IPS environment variable if available, or '127.0.0.1'. "
"The literal '*' means trust everything.",
)
@click.option(
"--root-path",
@@ -258,8 +255,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
"--limit-concurrency",
type=int,
default=None,
help="Maximum number of concurrent connections or tasks to allow, before issuing"
" HTTP 503 responses.",
help="Maximum number of concurrent connections or tasks to allow, before issuing" " HTTP 503 responses.",
)
@click.option(
"--backlog",
@@ -286,9 +282,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
default=None,
help="Maximum number of seconds to wait for graceful shutdown.",
)
@click.option(
"--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True
)
@click.option("--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True)
@click.option(
"--ssl-certfile",
type=str,
@@ -490,7 +484,7 @@ def run(
reload_delay: float = 0.25,
workers: int | None = None,
env_file: str | os.PathLike[str] | None = None,
log_config: dict[str, Any] | str | None = LOGGING_CONFIG,
log_config: dict[str, Any] | str | RawConfigParser | IO[Any] | None = LOGGING_CONFIG,
log_level: str | int | None = None,
access_log: bool = True,
proxy_headers: bool = True,
@@ -503,7 +497,7 @@ def run(
limit_max_requests: int | None = None,
timeout_keep_alive: int = 5,
timeout_graceful_shutdown: int | None = None,
ssl_keyfile: str | None = None,
ssl_keyfile: str | os.PathLike[str] | None = None,
ssl_certfile: str | os.PathLike[str] | None = None,
ssl_keyfile_password: str | None = None,
ssl_version: int = SSL_PROTOCOL_VERSION,
@@ -571,22 +565,23 @@ def run(
if (config.reload or config.workers > 1) and not isinstance(app, str):
logger = logging.getLogger("uvicorn.error")
logger.warning(
"You must pass the application as an import string to enable 'reload' or "
"'workers'."
)
logger.warning("You must pass the application as an import string to enable 'reload' or " "'workers'.")
sys.exit(1)
if config.should_reload:
sock = config.bind_socket()
ChangeReload(config, target=server.run, sockets=[sock]).run()
elif config.workers > 1:
sock = config.bind_socket()
Multiprocess(config, target=server.run, sockets=[sock]).run()
else:
server.run()
if config.uds and os.path.exists(config.uds):
os.remove(config.uds) # pragma: py-win32
try:
if config.should_reload:
sock = config.bind_socket()
ChangeReload(config, target=server.run, sockets=[sock]).run()
elif config.workers > 1:
sock = config.bind_socket()
Multiprocess(config, target=server.run, sockets=[sock]).run()
else:
server.run()
except KeyboardInterrupt:
pass # pragma: full coverage
finally:
if config.uds and os.path.exists(config.uds):
os.remove(config.uds) # pragma: py-win32
if not server.started and not config.should_reload and config.workers == 1:
sys.exit(STARTUP_FAILURE)

View File

@@ -10,8 +10,6 @@ class ASGI2Middleware:
def __init__(self, app: "ASGI2Application"):
self.app = app
async def __call__(
self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
) -> None:
async def __call__(self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None:
instance = self.app(scope)
await instance(receive, send)

View File

@@ -1,84 +1,142 @@
"""
This middleware can be used when a known proxy is fronting the application,
and is trusted to be properly setting the `X-Forwarded-Proto` and
`X-Forwarded-For` headers with the connecting client information.
from __future__ import annotations
Modifies the `client` and `scheme` information so that they reference
the connecting client, rather that the connecting proxy.
import ipaddress
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies
"""
from typing import List, Optional, Tuple, Union, cast
from uvicorn._types import (
ASGI3Application,
ASGIReceiveCallable,
ASGISendCallable,
HTTPScope,
Scope,
WebSocketScope,
)
from uvicorn._types import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, Scope
class ProxyHeadersMiddleware:
def __init__(
self,
app: "ASGI3Application",
trusted_hosts: Union[List[str], str] = "127.0.0.1",
) -> None:
"""Middleware for handling known proxy headers
This middleware can be used when a known proxy is fronting the application,
and is trusted to be properly setting the `X-Forwarded-Proto` and
`X-Forwarded-For` headers with the connecting client information.
Modifies the `client` and `scheme` information so that they reference
the connecting client, rather that the connecting proxy.
References:
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies>
- <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For>
"""
def __init__(self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1") -> None:
self.app = app
if isinstance(trusted_hosts, str):
self.trusted_hosts = {item.strip() for item in trusted_hosts.split(",")}
else:
self.trusted_hosts = set(trusted_hosts)
self.always_trust = "*" in self.trusted_hosts
self.trusted_hosts = _TrustedHosts(trusted_hosts)
def get_trusted_client_host(
self, x_forwarded_for_hosts: List[str]
) -> Optional[str]:
if self.always_trust:
return x_forwarded_for_hosts[0]
async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
if scope["type"] == "lifespan":
return await self.app(scope, receive, send)
for host in reversed(x_forwarded_for_hosts):
if host not in self.trusted_hosts:
return host
client_addr = scope.get("client")
client_host = client_addr[0] if client_addr else None
return None
if client_host in self.trusted_hosts:
headers = dict(scope["headers"])
async def __call__(
self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
) -> None:
if scope["type"] in ("http", "websocket"):
scope = cast(Union["HTTPScope", "WebSocketScope"], scope)
client_addr: Optional[Tuple[str, int]] = scope.get("client")
client_host = client_addr[0] if client_addr else None
if b"x-forwarded-proto" in headers:
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1").strip()
if self.always_trust or client_host in self.trusted_hosts:
headers = dict(scope["headers"])
if b"x-forwarded-proto" in headers:
# Determine if the incoming request was http or https based on
# the X-Forwarded-Proto header.
x_forwarded_proto = (
headers[b"x-forwarded-proto"].decode("latin1").strip()
)
if x_forwarded_proto in {"http", "https", "ws", "wss"}:
if scope["type"] == "websocket":
scope["scheme"] = (
"wss" if x_forwarded_proto == "https" else "ws"
)
scope["scheme"] = x_forwarded_proto.replace("http", "ws")
else:
scope["scheme"] = x_forwarded_proto
if b"x-forwarded-for" in headers:
# Determine the client address from the last trusted IP in the
# X-Forwarded-For header. We've lost the connecting client's port
# information by now, so only include the host.
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
x_forwarded_for_hosts = [
item.strip() for item in x_forwarded_for.split(",")
]
host = self.get_trusted_client_host(x_forwarded_for_hosts)
if b"x-forwarded-for" in headers:
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)
if host:
# If the x-forwarded-for header is empty then host is an empty string.
# Only set the client if we actually got something usable.
# See: https://github.com/encode/uvicorn/issues/1068
# We've lost the connecting client's port information by now,
# so only include the host.
port = 0
scope["client"] = (host, port) # type: ignore[arg-type]
scope["client"] = (host, port)
return await self.app(scope, receive, send)
def _parse_raw_hosts(value: str) -> list[str]:
return [item.strip() for item in value.split(",")]
class _TrustedHosts:
"""Container for trusted hosts and networks"""
def __init__(self, trusted_hosts: list[str] | str) -> None:
self.always_trust: bool = trusted_hosts in ("*", ["*"])
self.trusted_literals: set[str] = set()
self.trusted_hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
self.trusted_networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set()
# Notes:
# - We separate hosts from literals as there are many ways to write
# an IPv6 Address so we need to compare by object.
# - We don't convert IP Address to single host networks (e.g. /32 / 128) as
# it more efficient to do an address lookup in a set than check for
# membership in each network.
# - We still allow literals as it might be possible that we receive a
# something that isn't an IP Address e.g. a unix socket.
if not self.always_trust:
if isinstance(trusted_hosts, str):
trusted_hosts = _parse_raw_hosts(trusted_hosts)
for host in trusted_hosts:
# Note: because we always convert invalid IP types to literals it
# is not possible for the user to know they provided a malformed IP
# type - this may lead to unexpected / difficult to debug behaviour.
if "/" in host:
# Looks like a network
try:
self.trusted_networks.add(ipaddress.ip_network(host))
except ValueError:
# Was not a valid IP Network
self.trusted_literals.add(host)
else:
try:
self.trusted_hosts.add(ipaddress.ip_address(host))
except ValueError:
# Was not a valid IP Address
self.trusted_literals.add(host)
def __contains__(self, host: str | None) -> bool:
if self.always_trust:
return True
if not host:
return False
try:
ip = ipaddress.ip_address(host)
if ip in self.trusted_hosts:
return True
return any(ip in net for net in self.trusted_networks)
except ValueError:
return host in self.trusted_literals
def get_trusted_client_host(self, x_forwarded_for: str) -> str:
"""Extract the client host from x_forwarded_for header
In general this is the first "untrusted" host in the forwarded for list.
"""
x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for)
if self.always_trust:
return x_forwarded_for_hosts[0]
# Note: each proxy appends to the header list so check it in reverse order
for host in reversed(x_forwarded_for_hosts):
if host not in self:
return host
# All hosts are trusted meaning that the client was also a trusted proxy
# See https://github.com/encode/uvicorn/issues/1068#issuecomment-855371576
return x_forwarded_for_hosts[0]

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import io
import sys
import warnings
from collections import deque
from typing import Deque, Iterable, Optional, Tuple
from collections.abc import Iterable
from uvicorn._types import (
ASGIReceiveCallable,
@@ -22,9 +24,7 @@ from uvicorn._types import (
)
def build_environ(
scope: "HTTPScope", message: "ASGIReceiveEvent", body: io.BytesIO
) -> Environ:
def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: io.BytesIO) -> Environ:
"""
Builds a scope and request message into a WSGI environ object.
"""
@@ -91,9 +91,9 @@ class _WSGIMiddleware:
async def __call__(
self,
scope: "HTTPScope",
receive: "ASGIReceiveCallable",
send: "ASGISendCallable",
scope: HTTPScope,
receive: ASGIReceiveCallable,
send: ASGISendCallable,
) -> None:
assert scope["type"] == "http"
instance = WSGIResponder(self.app, self.executor, scope)
@@ -105,7 +105,7 @@ class WSGIResponder:
self,
app: WSGIApp,
executor: concurrent.futures.ThreadPoolExecutor,
scope: "HTTPScope",
scope: HTTPScope,
):
self.app = app
self.executor = executor
@@ -113,21 +113,19 @@ class WSGIResponder:
self.status = None
self.response_headers = None
self.send_event = asyncio.Event()
self.send_queue: Deque[Optional["ASGISendEvent"]] = deque()
self.send_queue: deque[ASGISendEvent | None] = deque()
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
self.response_started = False
self.exc_info: Optional[ExcInfo] = None
self.exc_info: ExcInfo | None = None
async def __call__(
self, receive: "ASGIReceiveCallable", send: "ASGISendCallable"
) -> None:
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
message: HTTPRequestEvent = await receive() # type: ignore[assignment]
body = io.BytesIO(message.get("body", b""))
more_body = message.get("more_body", False)
if more_body:
body.seek(0, io.SEEK_END)
while more_body:
body_message: "HTTPRequestEvent" = (
body_message: HTTPRequestEvent = (
await receive() # type: ignore[assignment]
)
body.write(body_message.get("body", b""))
@@ -135,9 +133,7 @@ class WSGIResponder:
body.seek(0)
environ = build_environ(self.scope, message, body)
self.loop = asyncio.get_event_loop()
wsgi = self.loop.run_in_executor(
self.executor, self.wsgi, environ, self.start_response
)
wsgi = self.loop.run_in_executor(self.executor, self.wsgi, environ, self.start_response)
sender = self.loop.create_task(self.sender(send))
try:
await asyncio.wait_for(wsgi, None)
@@ -148,7 +144,7 @@ class WSGIResponder:
if self.exc_info is not None:
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
async def sender(self, send: "ASGISendCallable") -> None:
async def sender(self, send: ASGISendCallable) -> None:
while True:
if self.send_queue:
message = self.send_queue.popleft()
@@ -162,18 +158,15 @@ class WSGIResponder:
def start_response(
self,
status: str,
response_headers: Iterable[Tuple[str, str]],
exc_info: Optional[ExcInfo] = None,
response_headers: Iterable[tuple[str, str]],
exc_info: ExcInfo | None = None,
) -> None:
self.exc_info = exc_info
if not self.response_started:
self.response_started = True
status_code_str, _ = status.split(" ", 1)
status_code = int(status_code_str)
headers = [
(name.encode("ascii"), value.encode("ascii"))
for name, value in response_headers
]
headers = [(name.encode("ascii"), value.encode("ascii")) for name, value in response_headers]
http_response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": status_code,

View File

@@ -1,12 +1,6 @@
import asyncio
from uvicorn._types import (
ASGIReceiveCallable,
ASGISendCallable,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
Scope,
)
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
CLOSE_HEADER = (b"connection", b"close")
@@ -22,7 +16,7 @@ class FlowControl:
self._is_writable_event.set()
async def drain(self) -> None:
await self._is_writable_event.wait()
await self._is_writable_event.wait() # pragma: full coverage
def pause_reading(self) -> None:
if not self.read_paused:
@@ -35,32 +29,26 @@ class FlowControl:
self._transport.resume_reading()
def pause_writing(self) -> None:
if not self.write_paused:
if not self.write_paused: # pragma: full coverage
self.write_paused = True
self._is_writable_event.clear()
def resume_writing(self) -> None:
if self.write_paused:
if self.write_paused: # pragma: full coverage
self.write_paused = False
self._is_writable_event.set()
async def service_unavailable(
scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
) -> None:
response_start: "HTTPResponseStartEvent" = {
"type": "http.response.start",
"status": 503,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
await send(response_start)
response_body: "HTTPResponseBodyEvent" = {
"type": "http.response.body",
"body": b"Service Unavailable",
"more_body": False,
}
await send(response_body)
async def service_unavailable(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
await send(
{
"type": "http.response.start",
"status": 503,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"content-length", b"19"),
(b"connection", b"close"),
],
}
)
await send({"type": "http.response.body", "body": b"Service Unavailable", "more_body": False})

View File

@@ -20,19 +20,8 @@ from uvicorn._types import (
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
FlowControl,
service_unavailable,
)
from uvicorn.protocols.utils import (
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.server import ServerState
@@ -43,9 +32,7 @@ def _get_status_phrase(status_code: int) -> bytes:
return b""
STATUS_PHRASES = {
status_code: _get_status_phrase(status_code) for status_code in range(100, 600)
}
STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
class H11Protocol(asyncio.Protocol):
@@ -160,14 +147,24 @@ class H11Protocol(asyncio.Protocol):
def _should_upgrade_to_ws(self) -> bool:
if self.ws_protocol_class is None:
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False
return True
def _unsupported_upgrade_warning(self) -> None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
if not self._should_upgrade_to_ws():
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
def _should_upgrade(self) -> bool:
upgrade = self._get_upgrade()
if upgrade == b"websocket" and self._should_upgrade_to_ws():
return True
if upgrade is not None:
self._unsupported_upgrade_warning()
return False
def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()
@@ -203,10 +200,7 @@ class H11Protocol(asyncio.Protocol):
full_raw_path = self.root_path.encode("ascii") + raw_path
self.scope = {
"type": "http",
"asgi": {
"version": self.config.asgi_version,
"spec_version": "2.3",
},
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"http_version": event.http_version.decode("ascii"),
"server": self.server,
"client": self.client,
@@ -219,16 +213,13 @@ class H11Protocol(asyncio.Protocol):
"headers": self.headers,
"state": self.app_state.copy(),
}
upgrade = self._get_upgrade()
if upgrade == b"websocket" and self._should_upgrade_to_ws():
if self._should_upgrade():
self.handle_websocket_upgrade(event)
return
# Handle 503 responses when 'limit_concurrency' is exceeded.
if self.limit_concurrency is not None and (
len(self.connections) >= self.limit_concurrency
or len(self.tasks) >= self.limit_concurrency
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
):
app = service_unavailable
message = "Exceeded concurrency limit."
@@ -275,9 +266,11 @@ class H11Protocol(asyncio.Protocol):
continue
self.cycle.more_body = False
self.cycle.message_event.set()
if self.conn.their_state == h11.MUST_CLOSE:
break
def handle_websocket_upgrade(self, event: h11.Request) -> None:
if self.logger.level <= TRACE_LOG_LEVEL:
if self.logger.level <= TRACE_LOG_LEVEL: # pragma: full coverage
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
@@ -322,9 +315,7 @@ class H11Protocol(asyncio.Protocol):
# Set a short Keep-Alive timeout.
self._unset_keepalive_if_required()
self.timeout_keep_alive_task = self.loop.call_later(
self.timeout_keep_alive, self.timeout_keep_alive_handler
)
self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler)
# Unpause data reads if needed.
self.flow.resume_reading()
@@ -349,13 +340,13 @@ class H11Protocol(asyncio.Protocol):
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.flow.pause_writing()
self.flow.pause_writing() # pragma: full coverage
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.flow.resume_writing()
self.flow.resume_writing() # pragma: full coverage
def timeout_keep_alive_handler(self) -> None:
"""
@@ -371,7 +362,7 @@ class H11Protocol(asyncio.Protocol):
class RequestResponseCycle:
def __init__(
self,
scope: "HTTPScope",
scope: HTTPScope,
conn: h11.Connection,
transport: asyncio.Transport,
flow: FlowControl,
@@ -407,7 +398,7 @@ class RequestResponseCycle:
self.response_complete = False
# ASGI exception wrapper
async def run_asgi(self, app: "ASGI3Application") -> None:
async def run_asgi(self, app: ASGI3Application) -> None:
try:
result = await app( # type: ignore[func-returns-value]
self.scope, self.receive, self.send
@@ -436,7 +427,7 @@ class RequestResponseCycle:
self.on_response = lambda: None
async def send_500_response(self) -> None:
response_start_event: "HTTPResponseStartEvent" = {
response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": 500,
"headers": [
@@ -445,7 +436,7 @@ class RequestResponseCycle:
],
}
await self.send(response_start_event)
response_body_event: "HTTPResponseBodyEvent" = {
response_body_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
@@ -453,14 +444,14 @@ class RequestResponseCycle:
await self.send(response_body_event)
# ASGI interface
async def send(self, message: "ASGISendEvent") -> None:
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if self.flow.write_paused and not self.disconnected:
await self.flow.drain()
await self.flow.drain() # pragma: full coverage
if self.disconnected:
return
return # pragma: full coverage
if not self.response_started:
# Sending response status line and headers
@@ -527,12 +518,10 @@ class RequestResponseCycle:
self.transport.close()
self.on_response()
async def receive(self) -> "ASGIReceiveEvent":
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
headers: list[tuple[str, str]] = []
event = h11.InformationalResponse(
status_code=100, headers=headers, reason="Continue"
)
event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue")
output = self.conn.send(event=event)
self.transport.write(output)
self.waiting_for_100_continue = False
@@ -545,7 +534,7 @@ class RequestResponseCycle:
if self.disconnected or self.response_complete:
return {"type": "http.disconnect"}
message: "HTTPRequestEvent" = {
message: HTTPRequestEvent = {
"type": "http.request",
"body": self.body,
"more_body": self.more_body,

View File

@@ -7,7 +7,7 @@ import re
import urllib
from asyncio.events import TimerHandle
from collections import deque
from typing import Any, Callable, Deque, Literal, cast
from typing import Any, Callable, Literal, cast
import httptools
@@ -15,31 +15,18 @@ from uvicorn._types import (
ASGI3Application,
ASGIReceiveEvent,
ASGISendEvent,
HTTPDisconnectEvent,
HTTPRequestEvent,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
HTTPScope,
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
FlowControl,
service_unavailable,
)
from uvicorn.protocols.utils import (
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.server import ServerState
HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")
HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x08\x0a-\x1f\x7f]")
def _get_status_line(status_code: int) -> bytes:
@@ -50,9 +37,7 @@ def _get_status_line(status_code: int) -> bytes:
return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"])
STATUS_LINE = {
status_code: _get_status_line(status_code) for status_code in range(100, 600)
}
STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
class HttpToolsProtocol(asyncio.Protocol):
@@ -73,6 +58,14 @@ class HttpToolsProtocol(asyncio.Protocol):
self.access_logger = logging.getLogger("uvicorn.access")
self.access_log = self.access_logger.hasHandlers()
self.parser = httptools.HttpRequestParser(self)
try:
# Enable dangerous leniencies to allow server to a response on the first request from a pipelined request.
self.parser.set_dangerous_leniencies(lenient_data_after_close=True)
except AttributeError: # pragma: no cover
# httptools < 0.6.3
pass
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
@@ -93,7 +86,7 @@ class HttpToolsProtocol(asyncio.Protocol):
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["http", "https"] | None = None
self.pipeline: Deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
# Per-request state
self.scope: HTTPScope = None # type: ignore[assignment]
@@ -154,21 +147,22 @@ class HttpToolsProtocol(asyncio.Protocol):
upgrade = value.lower()
if b"upgrade" in connection:
return upgrade
return None
return None # pragma: full coverage
def _should_upgrade_to_ws(self, upgrade: bytes | None) -> bool:
if upgrade == b"websocket" and self.ws_protocol_class is not None:
return True
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
def _should_upgrade_to_ws(self) -> bool:
if self.ws_protocol_class is None:
return False
return True
def _unsupported_upgrade_warning(self) -> None:
self.logger.warning("Unsupported upgrade request.")
if not self._should_upgrade_to_ws():
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False
def _should_upgrade(self) -> bool:
upgrade = self._get_upgrade()
return self._should_upgrade_to_ws(upgrade)
return upgrade == b"websocket" and self._should_upgrade_to_ws()
def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()
@@ -181,9 +175,10 @@ class HttpToolsProtocol(asyncio.Protocol):
self.send_400_response(msg)
return
except httptools.HttpParserUpgrade:
upgrade = self._get_upgrade()
if self._should_upgrade_to_ws(upgrade):
if self._should_upgrade():
self.handle_websocket_upgrade()
else:
self._unsupported_upgrade_warning()
def handle_websocket_upgrade(self) -> None:
if self.logger.level <= TRACE_LOG_LEVEL:
@@ -208,7 +203,7 @@ class HttpToolsProtocol(asyncio.Protocol):
def send_400_response(self, msg: str) -> None:
content = [STATUS_LINE[400]]
for name, value in self.server_state.default_headers:
content.extend([name, b": ", value, b"\r\n"])
content.extend([name, b": ", value, b"\r\n"]) # pragma: full coverage
content.extend(
[
b"content-type: text/plain; charset=utf-8\r\n",
@@ -268,8 +263,7 @@ class HttpToolsProtocol(asyncio.Protocol):
# Handle 503 responses when 'limit_concurrency' is exceeded.
if self.limit_concurrency is not None and (
len(self.connections) >= self.limit_concurrency
or len(self.tasks) >= self.limit_concurrency
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
):
app = service_unavailable
message = "Exceeded concurrency limit."
@@ -302,9 +296,7 @@ class HttpToolsProtocol(asyncio.Protocol):
self.pipeline.appendleft((self.cycle, app))
def on_body(self, body: bytes) -> None:
if (
self.parser.should_upgrade() and self._should_upgrade()
) or self.cycle.response_complete:
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
return
self.cycle.body += body
if len(self.cycle.body) > HIGH_WATER_LIMIT:
@@ -312,9 +304,7 @@ class HttpToolsProtocol(asyncio.Protocol):
self.cycle.message_event.set()
def on_message_complete(self) -> None:
if (
self.parser.should_upgrade() and self._should_upgrade()
) or self.cycle.response_complete:
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
return
self.cycle.more_body = False
self.cycle.message_event.set()
@@ -356,13 +346,13 @@ class HttpToolsProtocol(asyncio.Protocol):
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.flow.pause_writing()
self.flow.pause_writing() # pragma: full coverage
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.flow.resume_writing()
self.flow.resume_writing() # pragma: full coverage
def timeout_keep_alive_handler(self) -> None:
"""
@@ -376,7 +366,7 @@ class HttpToolsProtocol(asyncio.Protocol):
class RequestResponseCycle:
def __init__(
self,
scope: "HTTPScope",
scope: HTTPScope,
transport: asyncio.Transport,
flow: FlowControl,
logger: logging.Logger,
@@ -414,7 +404,7 @@ class RequestResponseCycle:
self.expected_content_length = 0
# ASGI exception wrapper
async def run_asgi(self, app: "ASGI3Application") -> None:
async def run_asgi(self, app: ASGI3Application) -> None:
try:
result = await app( # type: ignore[func-returns-value]
self.scope, self.receive, self.send
@@ -443,31 +433,28 @@ class RequestResponseCycle:
self.on_response = lambda: None
async def send_500_response(self) -> None:
response_start_event: "HTTPResponseStartEvent" = {
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
await self.send(response_start_event)
response_body_event: "HTTPResponseBodyEvent" = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
}
await self.send(response_body_event)
await self.send(
{
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"content-length", b"21"),
(b"connection", b"close"),
],
}
)
await self.send({"type": "http.response.body", "body": b"Internal Server Error", "more_body": False})
# ASGI interface
async def send(self, message: "ASGISendEvent") -> None:
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if self.flow.write_paused and not self.disconnected:
await self.flow.drain()
await self.flow.drain() # pragma: full coverage
if self.disconnected:
return
return # pragma: full coverage
if not self.response_started:
# Sending response status line and headers
@@ -500,7 +487,7 @@ class RequestResponseCycle:
for name, value in headers:
if HEADER_RE.search(name):
raise RuntimeError("Invalid HTTP header name.")
raise RuntimeError("Invalid HTTP header name.") # pragma: full coverage
if HEADER_VALUE_RE.search(value):
raise RuntimeError("Invalid HTTP header value.")
@@ -515,11 +502,7 @@ class RequestResponseCycle:
self.keep_alive = False
content.extend([name, b": ", value, b"\r\n"])
if (
self.chunked_encoding is None
and self.scope["method"] != "HEAD"
and status_code not in (204, 304)
):
if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304):
# Neither content-length nor transfer-encoding specified
self.chunked_encoding = True
content.append(b"transfer-encoding: chunked\r\n")
@@ -570,7 +553,7 @@ class RequestResponseCycle:
msg = "Unexpected ASGI message '%s' sent, after response already completed."
raise RuntimeError(msg % message_type)
async def receive(self) -> "ASGIReceiveEvent":
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
self.waiting_for_100_continue = False
@@ -580,15 +563,8 @@ class RequestResponseCycle:
await self.message_event.wait()
self.message_event.clear()
message: HTTPDisconnectEvent | HTTPRequestEvent
if self.disconnected or self.response_complete:
message = {"type": "http.disconnect"}
else:
message = {
"type": "http.request",
"body": self.body,
"more_body": self.more_body,
}
self.body = b""
return {"type": "http.disconnect"}
message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
self.body = b""
return message

View File

@@ -6,8 +6,7 @@ import urllib.parse
from uvicorn._types import WWWScope
class ClientDisconnected(IOError):
...
class ClientDisconnected(OSError): ...
def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
@@ -53,7 +52,5 @@ def get_client_addr(scope: WWWScope) -> str:
def get_path_with_query_string(scope: WWWScope) -> str:
path_with_query_string = urllib.parse.quote(scope["path"])
if scope["query_string"]:
path_with_query_string = "{}?{}".format(
path_with_query_string, scope["query_string"].decode("ascii")
)
path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii"))
return path_with_query_string

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import asyncio
import typing
AutoWebSocketsProtocol: typing.Optional[typing.Callable[..., asyncio.Protocol]]
AutoWebSocketsProtocol: typing.Callable[..., asyncio.Protocol] | None
try:
import websockets # noqa
except ImportError: # pragma: no cover

View File

@@ -3,27 +3,22 @@ from __future__ import annotations
import asyncio
import http
import logging
from typing import (
Any,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from collections.abc import Sequence
from typing import Any, Literal, Optional, cast
from urllib.parse import unquote
import websockets
import websockets.legacy.handshake
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.extensions.base import ServerExtensionFactory
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.legacy.server import HTTPResponse
from websockets.server import WebSocketServerProtocol
from websockets.typing import Subprotocol
from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
@@ -61,7 +56,8 @@ class Server:
class WebSocketProtocol(WebSocketServerProtocol):
extra_headers: List[Tuple[str, str]]
extra_headers: list[tuple[str, str]]
logger: logging.Logger | logging.LoggerAdapter[Any]
def __init__(
self,
@@ -74,7 +70,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
config.load()
self.config = config
self.app = config.loaded_app
self.app = cast(ASGI3Application, config.loaded_app)
self.loop = _loop or asyncio.get_event_loop()
self.root_path = config.root_path
self.app_state = app_state
@@ -101,7 +97,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.ws_server: Server = Server() # type: ignore[assignment]
extensions = []
extensions: list[ServerExtensionFactory] = []
if self.config.ws_per_message_deflate:
extensions.append(ServerPerMessageDeflateFactory())
@@ -117,8 +113,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
)
self.server_header = None
self.extra_headers = [
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in server_state.default_headers
(name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers
]
def connection_made( # type: ignore[override]
@@ -136,16 +131,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
super().connection_made(transport)
def connection_lost(self, exc: Optional[Exception]) -> None:
def connection_lost(self, exc: Exception | None) -> None:
self.connections.remove(self)
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
self.lost_connection_before_handshake = (
not self.handshake_completed_event.is_set()
)
self.lost_connection_before_handshake = not self.handshake_completed_event.is_set()
self.handshake_completed_event.set()
super().connection_lost(exc)
if exc is None:
@@ -159,12 +152,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.send_500_response()
self.transport.close()
def on_task_complete(self, task: asyncio.Task) -> None:
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)
async def process_request(
self, path: str, headers: Headers
) -> Optional[HTTPResponse]:
async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None:
"""
This hook is called to determine if the websocket should return
an HTTP response and close.
@@ -175,15 +166,15 @@ class WebSocketProtocol(WebSocketServerProtocol):
"""
path_portion, _, query_string = path.partition("?")
websockets.legacy.handshake.check_request(headers)
websockets.legacy.handshake.check_request(request_headers)
subprotocols = []
for header in headers.get_all("Sec-WebSocket-Protocol"):
subprotocols: list[str] = []
for header in request_headers.get_all("Sec-WebSocket-Protocol"):
subprotocols.extend([token.strip() for token in header.split(",")])
asgi_headers = [
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
for name, value in headers.raw_items()
for name, value in request_headers.raw_items()
]
path = unquote(path_portion)
full_path = self.root_path + path
@@ -212,8 +203,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
return self.initial_response
def process_subprotocol(
self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
) -> Optional[Subprotocol]:
self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
) -> Subprotocol | None:
"""
We override the standard 'process_subprotocol' behavior here so that
we return whatever subprotocol is sent in the 'accept' message.
@@ -223,8 +214,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
def send_500_response(self) -> None:
msg = b"Internal Server Error"
content = [
b"HTTP/1.1 500 Internal Server Error\r\n"
b"content-type: text/plain; charset=utf-8\r\n",
b"HTTP/1.1 500 Internal Server Error\r\n" b"content-type: text/plain; charset=utf-8\r\n",
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
b"connection: close\r\n",
b"\r\n",
@@ -235,9 +225,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
# itself (see https://github.com/encode/uvicorn/issues/920)
self.handshake_started_event.set()
async def ws_handler( # type: ignore[override]
self, protocol: WebSocketServerProtocol, path: str
) -> Any:
async def ws_handler(self, protocol: WebSocketServerProtocol, path: str) -> Any: # type: ignore[override]
"""
This is the main handler function for the 'websockets' implementation
to call into. We just wait for close then return, and instead allow
@@ -252,14 +240,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
termination states.
"""
try:
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
except ClientDisconnected:
result = await self.app(self.scope, self.asgi_receive, self.asgi_send) # type: ignore[func-returns-value]
except ClientDisconnected: # pragma: full coverage
self.closed_event.set()
self.transport.close()
except BaseException as exc:
except BaseException:
self.closed_event.set()
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
self.logger.exception("Exception in ASGI application\n")
if not self.handshake_started_event.is_set():
self.send_500_response()
else:
@@ -268,17 +255,15 @@ class WebSocketProtocol(WebSocketServerProtocol):
else:
self.closed_event.set()
if not self.handshake_started_event.is_set():
msg = "ASGI callable returned without sending handshake."
self.logger.error(msg)
self.logger.error("ASGI callable returned without sending handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
await self.handshake_completed_event.wait()
self.transport.close()
async def asgi_send(self, message: "ASGISendEvent") -> None:
async def asgi_send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if not self.handshake_started_event.is_set():
@@ -290,9 +275,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
get_path_with_query_string(self.scope),
)
self.initial_response = None
self.accepted_subprotocol = cast(
Optional[Subprotocol], message.get("subprotocol")
)
self.accepted_subprotocol = cast(Optional[Subprotocol], message.get("subprotocol"))
if "headers" in message:
self.extra_headers.extend(
# ASGI spec requires bytes
@@ -324,8 +307,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
# websockets requires the status to be an enum. look it up.
status = http.HTTPStatus(message["status"])
headers = [
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in message.get("headers", [])
(name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", [])
]
self.initial_response = (status, headers, b"")
self.handshake_started_event.set()
@@ -356,10 +338,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.closed_event.set()
else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
raise RuntimeError(msg % message_type)
except ConnectionClosed as exc:
raise ClientDisconnected from exc
@@ -372,24 +351,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
if not message.get("more_body", False):
self.closed_event.set()
else:
msg = (
"Expected ASGI message 'websocket.http.response.body' "
"but got '%s'."
)
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
raise RuntimeError(msg % message_type)
else:
msg = (
"Unexpected ASGI message '%s', after sending 'websocket.close' "
"or response already completed."
)
msg = "Unexpected ASGI message '%s', after sending 'websocket.close' " "or response already completed."
raise RuntimeError(msg % message_type)
async def asgi_receive(
self,
) -> Union[
"WebSocketDisconnectEvent", "WebSocketConnectEvent", "WebSocketReceiveEvent"
]:
async def asgi_receive(self) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
if not self.connect_sent:
self.connect_sent = True
return {"type": "websocket.connect"}
@@ -406,11 +375,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
try:
data = await self.recv()
except ConnectionClosed as exc:
except ConnectionClosed:
self.closed_event.set()
if self.ws_server.closing:
return {"type": "websocket.disconnect", "code": 1012}
return {"type": "websocket.disconnect", "code": exc.code}
return {"type": "websocket.disconnect", "code": self.close_code or 1005, "reason": self.close_reason}
if isinstance(data, str):
return {"type": "websocket.receive", "text": data}

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
import logging
import typing
from typing import Literal
from typing import Literal, cast
from urllib.parse import unquote
import wsproto
@@ -13,6 +13,7 @@ from wsproto.extensions import Extension, PerMessageDeflate
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
@@ -43,10 +44,10 @@ class WSProtocol(asyncio.Protocol):
_loop: asyncio.AbstractEventLoop | None = None,
) -> None:
if not config.loaded:
config.load()
config.load() # pragma: full coverage
self.config = config
self.app = config.loaded_app
self.app = cast(ASGI3Application, config.loaded_app)
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
@@ -139,13 +140,13 @@ class WSProtocol(asyncio.Protocol):
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.writable.clear()
self.writable.clear() # pragma: full coverage
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.writable.set()
self.writable.set() # pragma: full coverage
def shutdown(self) -> None:
if self.handshake_complete:
@@ -156,7 +157,7 @@ class WSProtocol(asyncio.Protocol):
self.send_500_response()
self.transport.close()
def on_task_complete(self, task: asyncio.Task) -> None:
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)
# Event handlers
@@ -168,7 +169,7 @@ class WSProtocol(asyncio.Protocol):
path = unquote(raw_path)
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii")
self.scope: "WebSocketScope" = {
self.scope: WebSocketScope = {
"type": "websocket",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
"http_version": "1.1",
@@ -211,7 +212,7 @@ class WSProtocol(asyncio.Protocol):
def handle_close(self, event: events.CloseConnection) -> None:
if self.conn.state == ConnectionState.REMOTE_CLOSING:
self.transport.write(self.conn.send(event.response()))
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code})
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code, "reason": event.reason})
self.transport.close()
def handle_ping(self, event: events.Ping) -> None:
@@ -220,38 +221,31 @@ class WSProtocol(asyncio.Protocol):
def send_500_response(self) -> None:
if self.response_started or self.handshake_complete:
return # we cannot send responses anymore
headers = [
headers: list[tuple[bytes, bytes]] = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
(b"content-length", b"21"),
]
output = self.conn.send(
wsproto.events.RejectConnection(
status_code=500, headers=headers, has_body=True
)
)
output += self.conn.send(
wsproto.events.RejectData(data=b"Internal Server Error")
)
output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True))
output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error"))
self.transport.write(output)
async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
result = await self.app(self.scope, self.receive, self.send) # type: ignore[func-returns-value]
except ClientDisconnected:
self.transport.close()
self.transport.close() # pragma: full coverage
except BaseException:
self.logger.exception("Exception in ASGI application\n")
self.send_500_response()
self.transport.close()
else:
if not self.handshake_complete:
msg = "ASGI callable returned without completing handshake."
self.logger.error(msg)
self.logger.error("ASGI callable returned without completing handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
self.transport.close()
async def send(self, message: ASGISendEvent) -> None:
@@ -269,7 +263,7 @@ class WSProtocol(asyncio.Protocol):
)
subprotocol = message.get("subprotocol")
extra_headers = self.default_headers + list(message.get("headers", []))
extensions: typing.List[Extension] = []
extensions: list[Extension] = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
if not self.transport.is_closing():
@@ -343,21 +337,14 @@ class WSProtocol(asyncio.Protocol):
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait(
{"type": "websocket.disconnect", "code": code}
)
output = self.conn.send(
wsproto.events.CloseConnection(code=code, reason=reason)
)
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason))
if not self.transport.is_closing():
self.transport.write(output)
self.transport.close()
else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
raise RuntimeError(msg % message_type)
except LocalProtocolError as exc:
raise ClientDisconnected from exc
@@ -365,24 +352,17 @@ class WSProtocol(asyncio.Protocol):
if message_type == "websocket.http.response.body":
message = typing.cast("WebSocketResponseBodyEvent", message)
body_finished = not message.get("more_body", False)
reject_data = events.RejectData(
data=message["body"], body_finished=body_finished
)
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
output = self.conn.send(reject_data)
self.transport.write(output)
if body_finished:
self.queue.put_nowait(
{"type": "websocket.disconnect", "code": 1006}
)
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.close_sent = True
self.transport.close()
else:
msg = (
"Expected ASGI message 'websocket.http.response.body' "
"but got '%s'."
)
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
raise RuntimeError(msg % message_type)
else:

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
import os
import platform
@@ -9,9 +10,10 @@ import socket
import sys
import threading
import time
from collections.abc import Generator, Sequence
from email.utils import formatdate
from types import FrameType
from typing import TYPE_CHECKING, Sequence, Union
from typing import TYPE_CHECKING, Union
import click
@@ -57,11 +59,17 @@ class Server:
self.force_exit = False
self.last_notified = 0.0
self._captured_signals: list[int] = []
def run(self, sockets: list[socket.socket] | None = None) -> None:
self.config.setup_event_loop()
return asyncio.run(self.serve(sockets=sockets))
async def serve(self, sockets: list[socket.socket] | None = None) -> None:
with self.capture_signals():
await self._serve(sockets)
async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
process_id = os.getpid()
config = self.config
@@ -70,8 +78,6 @@ class Server:
self.lifespan = config.lifespan_class(config)
self.install_signal_handlers()
message = "Started server process [%d]"
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
logger.info(message, process_id, extra={"color_message": color_message})
@@ -107,7 +113,7 @@ class Server:
loop = asyncio.get_running_loop()
listeners: Sequence[socket.SocketType]
if sockets is not None:
if sockets is not None: # pragma: full coverage
# Explicitly passed a list of open sockets.
# We use this when the server is run from a Gunicorn worker.
@@ -126,18 +132,14 @@ class Server:
is_windows = platform.system() == "Windows"
if config.workers > 1 and is_windows: # pragma: py-not-win32
sock = _share_socket(sock) # type: ignore[assignment]
server = await loop.create_server(
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
)
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
self.servers.append(server)
listeners = sockets
elif config.fd is not None: # pragma: py-win32
# Use an existing socket, from a file descriptor.
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
server = await loop.create_server(
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
)
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
assert server.sockets is not None # mypy
listeners = server.sockets
self.servers = [server]
@@ -146,7 +148,7 @@ class Server:
# Create a socket using UNIX domain socket.
uds_perms = 0o666
if os.path.exists(config.uds):
uds_perms = os.stat(config.uds).st_mode
uds_perms = os.stat(config.uds).st_mode # pragma: full coverage
server = await loop.create_unix_server(
create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
)
@@ -179,7 +181,7 @@ class Server:
else:
# We're most likely running multiple workers, so a message has already been
# logged by `config.bind_socket()`.
pass
pass # pragma: full coverage
self.started = True
@@ -194,9 +196,7 @@ class Server:
)
elif config.uds is not None: # pragma: py-win32
logger.info(
"Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds
)
logger.info("Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds)
else:
addr_format = "%s://%s:%d"
@@ -211,11 +211,7 @@ class Server:
protocol_name = "https" if config.ssl else "http"
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
color_message = (
"Uvicorn running on "
+ click.style(addr_format, bold=True)
+ " (Press CTRL+C to quit)"
)
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
logger.info(
message,
protocol_name,
@@ -244,21 +240,23 @@ class Server:
else:
date_header = []
self.server_state.default_headers = (
date_header + self.config.encoded_headers
)
self.server_state.default_headers = date_header + self.config.encoded_headers
# Callback to `callback_notify` once every `timeout_notify` seconds.
if self.config.callback_notify is not None:
if current_time - self.last_notified > self.config.timeout_notify:
if current_time - self.last_notified > self.config.timeout_notify: # pragma: full coverage
self.last_notified = current_time
await self.config.callback_notify()
# Determine if we should exit.
if self.should_exit:
return True
if self.config.limit_max_requests is not None:
return self.server_state.total_requests >= self.config.limit_max_requests
max_requests = self.config.limit_max_requests
if max_requests is not None and self.server_state.total_requests >= max_requests:
logger.warning(f"Maximum request limit of {max_requests} exceeded. Terminating process.")
return True
return False
async def shutdown(self, sockets: list[socket.socket] | None = None) -> None:
@@ -268,7 +266,7 @@ class Server:
for server in self.servers:
server.close()
for sock in sockets or []:
sock.close()
sock.close() # pragma: full coverage
# Request shutdown on all existing connections.
for connection in list(self.server_state.connections):
@@ -287,10 +285,7 @@ class Server:
len(self.server_state.tasks),
)
for t in self.server_state.tasks:
if sys.version_info < (3, 9): # pragma: py-gte-39
t.cancel()
else: # pragma: py-lt-39
t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")
t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")
# Send the lifespan shutdown event, and wait for application shutdown.
if not self.force_exit:
@@ -314,23 +309,29 @@ class Server:
for server in self.servers:
await server.wait_closed()
def install_signal_handlers(self) -> None:
@contextlib.contextmanager
def capture_signals(self) -> Generator[None, None, None]:
# Signals can only be listened to from the main thread.
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
yield
return
loop = asyncio.get_event_loop()
# always use signal.signal, even if loop.add_signal_handler is available
# this allows to restore previous signal handlers later on
original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self.handle_exit, sig, None)
except NotImplementedError: # pragma: no cover
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.handle_exit)
yield
finally:
for sig, handler in original_handlers.items():
signal.signal(sig, handler)
# If we did gracefully shut down due to a signal, try to
# trigger the expected behaviour now; multiple signals would be
# done LIFO, see https://stackoverflow.com/questions/48434964
for captured_signal in reversed(self._captured_signals):
signal.raise_signal(captured_signal)
def handle_exit(self, sig: int, frame: FrameType | None) -> None:
self._captured_signals.append(sig)
if self.should_exit and sig == signal.SIGINT:
self.force_exit = True
self.force_exit = True # pragma: full coverage
else:
self.should_exit = True

View File

@@ -1,21 +1,16 @@
from typing import TYPE_CHECKING, Type
from __future__ import annotations
from typing import TYPE_CHECKING
from uvicorn.supervisors.basereload import BaseReload
from uvicorn.supervisors.multiprocess import Multiprocess
if TYPE_CHECKING:
ChangeReload: Type[BaseReload]
ChangeReload: type[BaseReload]
else:
try:
from uvicorn.supervisors.watchfilesreload import (
WatchFilesReload as ChangeReload,
)
from uvicorn.supervisors.watchfilesreload import WatchFilesReload as ChangeReload
except ImportError: # pragma: no cover
try:
from uvicorn.supervisors.watchgodreload import (
WatchGodReload as ChangeReload,
)
except ImportError:
from uvicorn.supervisors.statreload import StatReload as ChangeReload
from uvicorn.supervisors.statreload import StatReload as ChangeReload
__all__ = ["Multiprocess", "ChangeReload"]

View File

@@ -5,10 +5,11 @@ import os
import signal
import sys
import threading
from collections.abc import Iterator
from pathlib import Path
from socket import socket
from types import FrameType
from typing import Callable, Iterator
from typing import Callable
import click
@@ -38,14 +39,14 @@ class BaseReload:
self.is_restarting = False
self.reloader_name: str | None = None
def signal_handler(self, sig: int, frame: FrameType | None) -> None:
def signal_handler(self, sig: int, frame: FrameType | None) -> None: # pragma: full coverage
"""
A signal handler that is registered with the parent process.
"""
if sys.platform == "win32" and self.is_restarting:
self.is_restarting = False # pragma: py-not-win32
self.is_restarting = False
else:
self.should_exit.set() # pragma: py-win32
self.should_exit.set()
def run(self) -> None:
self.startup()
@@ -81,9 +82,7 @@ class BaseReload:
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.signal_handler)
self.process = get_subprocess(
config=self.config, target=self.target, sockets=self.sockets
)
self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
self.process.start()
def restart(self) -> None:
@@ -95,9 +94,7 @@ class BaseReload:
self.process.terminate()
self.process.join()
self.process = get_subprocess(
config=self.config, target=self.target, sockets=self.sockets
)
self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
self.process.start()
def shutdown(self) -> None:
@@ -110,10 +107,8 @@ class BaseReload:
for sock in self.sockets:
sock.close()
message = "Stopping reloader process [{}]".format(str(self.pid))
color_message = "Stopping reloader process [{}]".format(
click.style(str(self.pid), fg="cyan", bold=True)
)
message = f"Stopping reloader process [{str(self.pid)}]"
color_message = "Stopping reloader process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True))
logger.info(message, extra={"color_message": color_message})
def should_restart(self) -> list[Path] | None:

View File

@@ -4,24 +4,101 @@ import logging
import os
import signal
import threading
from multiprocessing.context import SpawnProcess
from multiprocessing import Pipe
from socket import socket
from types import FrameType
from typing import Callable
from typing import Any, Callable
import click
from uvicorn._subprocess import get_subprocess
from uvicorn.config import Config
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
)
SIGNALS = {
getattr(signal, f"SIG{x}"): x
for x in "INT TERM BREAK HUP QUIT TTIN TTOU USR1 USR2 WINCH".split()
if hasattr(signal, f"SIG{x}")
}
logger = logging.getLogger("uvicorn.error")
class Process:
def __init__(
self,
config: Config,
target: Callable[[list[socket] | None], None],
sockets: list[socket],
) -> None:
self.real_target = target
self.parent_conn, self.child_conn = Pipe()
self.process = get_subprocess(config, self.target, sockets)
def ping(self, timeout: float = 5) -> bool:
self.parent_conn.send(b"ping")
if self.parent_conn.poll(timeout):
self.parent_conn.recv()
return True
return False
def pong(self) -> None:
self.child_conn.recv()
self.child_conn.send(b"pong")
def always_pong(self) -> None:
while True:
self.pong()
def target(self, sockets: list[socket] | None = None) -> Any: # pragma: no cover
if os.name == "nt": # pragma: py-not-win32
# Windows doesn't support SIGTERM, so we use SIGBREAK instead.
# And then we raise SIGTERM when SIGBREAK is received.
# https://learn.microsoft.com/zh-cn/cpp/c-runtime-library/reference/signal?view=msvc-170
signal.signal(
signal.SIGBREAK, # type: ignore[attr-defined]
lambda sig, frame: signal.raise_signal(signal.SIGTERM),
)
threading.Thread(target=self.always_pong, daemon=True).start()
return self.real_target(sockets)
def is_alive(self, timeout: float = 5) -> bool:
if not self.process.is_alive():
return False # pragma: full coverage
return self.ping(timeout)
def start(self) -> None:
self.process.start()
def terminate(self) -> None:
if self.process.exitcode is None: # Process is still running
assert self.process.pid is not None
if os.name == "nt": # pragma: py-not-win32
# Windows doesn't support SIGTERM.
# So send SIGBREAK, and then in process raise SIGTERM.
os.kill(self.process.pid, signal.CTRL_BREAK_EVENT) # type: ignore[attr-defined]
else:
os.kill(self.process.pid, signal.SIGTERM)
logger.info(f"Terminated child process [{self.process.pid}]")
self.parent_conn.close()
self.child_conn.close()
def kill(self) -> None:
# In Windows, the method will call `TerminateProcess` to kill the process.
# In Unix, the method will send SIGKILL to the process.
self.process.kill()
def join(self) -> None:
logger.info(f"Waiting for child process [{self.process.pid}]")
self.process.join()
@property
def pid(self) -> int | None:
return self.process.pid
class Multiprocess:
def __init__(
self,
@@ -32,45 +109,114 @@ class Multiprocess:
self.config = config
self.target = target
self.sockets = sockets
self.processes: list[SpawnProcess] = []
self.processes_num = config.workers
self.processes: list[Process] = []
self.should_exit = threading.Event()
self.pid = os.getpid()
def signal_handler(self, sig: int, frame: FrameType | None) -> None:
"""
A signal handler that is registered with the parent process.
"""
self.should_exit.set()
self.signal_queue: list[int] = []
for sig in SIGNALS:
signal.signal(sig, lambda sig, frame: self.signal_queue.append(sig))
def run(self) -> None:
self.startup()
self.should_exit.wait()
self.shutdown()
def startup(self) -> None:
message = "Started parent process [{}]".format(str(self.pid))
color_message = "Started parent process [{}]".format(
click.style(str(self.pid), fg="cyan", bold=True)
)
logger.info(message, extra={"color_message": color_message})
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.signal_handler)
for _idx in range(self.config.workers):
process = get_subprocess(
config=self.config, target=self.target, sockets=self.sockets
)
def init_processes(self) -> None:
for _ in range(self.processes_num):
process = Process(self.config, self.target, self.sockets)
process.start()
self.processes.append(process)
def shutdown(self) -> None:
def terminate_all(self) -> None:
for process in self.processes:
process.terminate()
def join_all(self) -> None:
for process in self.processes:
process.join()
message = "Stopping parent process [{}]".format(str(self.pid))
color_message = "Stopping parent process [{}]".format(
click.style(str(self.pid), fg="cyan", bold=True)
)
def restart_all(self) -> None:
for idx, process in enumerate(self.processes):
process.terminate()
process.join()
new_process = Process(self.config, self.target, self.sockets)
new_process.start()
self.processes[idx] = new_process
def run(self) -> None:
message = f"Started parent process [{os.getpid()}]"
color_message = "Started parent process [{}]".format(click.style(str(os.getpid()), fg="cyan", bold=True))
logger.info(message, extra={"color_message": color_message})
self.init_processes()
while not self.should_exit.wait(0.5):
self.handle_signals()
self.keep_subprocess_alive()
self.terminate_all()
self.join_all()
message = f"Stopping parent process [{os.getpid()}]"
color_message = "Stopping parent process [{}]".format(click.style(str(os.getpid()), fg="cyan", bold=True))
logger.info(message, extra={"color_message": color_message})
def keep_subprocess_alive(self) -> None:
if self.should_exit.is_set():
return # parent process is exiting, no need to keep subprocess alive
for idx, process in enumerate(self.processes):
if process.is_alive():
continue
process.kill() # process is hung, kill it
process.join()
if self.should_exit.is_set():
return # pragma: full coverage
logger.info(f"Child process [{process.pid}] died")
process = Process(self.config, self.target, self.sockets)
process.start()
self.processes[idx] = process
def handle_signals(self) -> None:
for sig in tuple(self.signal_queue):
self.signal_queue.remove(sig)
sig_name = SIGNALS[sig]
sig_handler = getattr(self, f"handle_{sig_name.lower()}", None)
if sig_handler is not None:
sig_handler()
else: # pragma: no cover
logger.debug(f"Received signal {sig_name}, but no handler is defined for it.")
def handle_int(self) -> None:
logger.info("Received SIGINT, exiting.")
self.should_exit.set()
def handle_term(self) -> None:
logger.info("Received SIGTERM, exiting.")
self.should_exit.set()
def handle_break(self) -> None: # pragma: py-not-win32
logger.info("Received SIGBREAK, exiting.")
self.should_exit.set()
def handle_hup(self) -> None: # pragma: py-win32
logger.info("Received SIGHUP, restarting processes.")
self.restart_all()
def handle_ttin(self) -> None: # pragma: py-win32
logger.info("Received SIGTTIN, increasing the number of processes.")
self.processes_num += 1
process = Process(self.config, self.target, self.sockets)
process.start()
self.processes.append(process)
def handle_ttou(self) -> None: # pragma: py-win32
logger.info("Received SIGTTOU, decreasing number of processes.")
if self.processes_num <= 1:
logger.info("Already reached one process, cannot decrease the number of processes anymore.")
return
self.processes_num -= 1
process = self.processes.pop()
process.terminate()
process.join()

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
import logging
from collections.abc import Iterator
from pathlib import Path
from socket import socket
from typing import Callable, Iterator
from typing import Callable
from uvicorn.config import Config
from uvicorn.supervisors.basereload import BaseReload
@@ -23,10 +24,7 @@ class StatReload(BaseReload):
self.mtimes: dict[Path, float] = {}
if config.reload_excludes or config.reload_includes:
logger.warning(
"--reload-include and --reload-exclude have no effect unless "
"watchfiles is installed."
)
logger.warning("--reload-include and --reload-exclude have no effect unless " "watchfiles is installed.")
def should_restart(self) -> list[Path] | None:
self.pause()

View File

@@ -13,20 +13,12 @@ from uvicorn.supervisors.basereload import BaseReload
class FileFilter:
def __init__(self, config: Config):
default_includes = ["*.py"]
self.includes = [
default
for default in default_includes
if default not in config.reload_excludes
]
self.includes = [default for default in default_includes if default not in config.reload_excludes]
self.includes.extend(config.reload_includes)
self.includes = list(set(self.includes))
default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]
self.excludes = [
default
for default in default_excludes
if default not in config.reload_includes
]
self.excludes = [default for default in default_excludes if default not in config.reload_includes]
self.exclude_dirs = []
for e in config.reload_excludes:
p = Path(e)
@@ -39,14 +31,14 @@ class FileFilter:
if is_dir:
self.exclude_dirs.append(p)
else:
self.excludes.append(e)
self.excludes.append(e) # pragma: full coverage
self.excludes = list(set(self.excludes))
def __call__(self, path: Path) -> bool:
for include_pattern in self.includes:
if path.match(include_pattern):
if str(path).endswith(include_pattern):
return True
return True # pragma: full coverage
for exclude_dir in self.exclude_dirs:
if exclude_dir in path.parents:
@@ -54,7 +46,7 @@ class FileFilter:
for exclude_pattern in self.excludes:
if path.match(exclude_pattern):
return False
return False # pragma: full coverage
return True
return False

View File

@@ -1,163 +0,0 @@
from __future__ import annotations
import logging
import warnings
from pathlib import Path
from socket import socket
from typing import TYPE_CHECKING, Callable
from watchgod import DefaultWatcher
from uvicorn.config import Config
from uvicorn.supervisors.basereload import BaseReload
if TYPE_CHECKING:
import os
DirEntry = os.DirEntry[str]
logger = logging.getLogger("uvicorn.error")
class CustomWatcher(DefaultWatcher):
def __init__(self, root_path: Path, config: Config):
default_includes = ["*.py"]
self.includes = [
default
for default in default_includes
if default not in config.reload_excludes
]
self.includes.extend(config.reload_includes)
self.includes = list(set(self.includes))
default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]
self.excludes = [
default
for default in default_excludes
if default not in config.reload_includes
]
self.excludes.extend(config.reload_excludes)
self.excludes = list(set(self.excludes))
self.watched_dirs: dict[str, bool] = {}
self.watched_files: dict[str, bool] = {}
self.dirs_includes = set(config.reload_dirs)
self.dirs_excludes = set(config.reload_dirs_excludes)
self.resolved_root = root_path
super().__init__(str(root_path))
def should_watch_file(self, entry: "DirEntry") -> bool:
cached_result = self.watched_files.get(entry.path)
if cached_result is not None:
return cached_result
entry_path = Path(entry)
# cwd is not verified through should_watch_dir, so we need to verify here
if entry_path.parent == Path.cwd() and Path.cwd() not in self.dirs_includes:
self.watched_files[entry.path] = False
return False
for include_pattern in self.includes:
if str(entry_path).endswith(include_pattern):
self.watched_files[entry.path] = True
return True
if entry_path.match(include_pattern):
for exclude_pattern in self.excludes:
if entry_path.match(exclude_pattern):
self.watched_files[entry.path] = False
return False
self.watched_files[entry.path] = True
return True
self.watched_files[entry.path] = False
return False
def should_watch_dir(self, entry: "DirEntry") -> bool:
cached_result = self.watched_dirs.get(entry.path)
if cached_result is not None:
return cached_result
entry_path = Path(entry)
if entry_path in self.dirs_excludes:
self.watched_dirs[entry.path] = False
return False
for exclude_pattern in self.excludes:
if entry_path.match(exclude_pattern):
is_watched = False
if entry_path in self.dirs_includes:
is_watched = True
for directory in self.dirs_includes:
if directory in entry_path.parents:
is_watched = True
if is_watched:
logger.debug(
"WatchGodReload detected a new excluded dir '%s' in '%s'; "
"Adding to exclude list.",
entry_path.relative_to(self.resolved_root),
str(self.resolved_root),
)
self.watched_dirs[entry.path] = False
self.dirs_excludes.add(entry_path)
return False
if entry_path in self.dirs_includes:
self.watched_dirs[entry.path] = True
return True
for directory in self.dirs_includes:
if directory in entry_path.parents:
self.watched_dirs[entry.path] = True
return True
for include_pattern in self.includes:
if entry_path.match(include_pattern):
logger.info(
"WatchGodReload detected a new reload dir '%s' in '%s'; "
"Adding to watch list.",
str(entry_path.relative_to(self.resolved_root)),
str(self.resolved_root),
)
self.dirs_includes.add(entry_path)
self.watched_dirs[entry.path] = True
return True
self.watched_dirs[entry.path] = False
return False
class WatchGodReload(BaseReload):
def __init__(
self,
config: Config,
target: Callable[[list[socket] | None], None],
sockets: list[socket],
) -> None:
warnings.warn(
'"watchgod" is deprecated, you should switch '
"to watchfiles (`pip install watchfiles`).",
DeprecationWarning,
)
super().__init__(config, target, sockets)
self.reloader_name = "WatchGod"
self.watchers = []
reload_dirs = []
for directory in config.reload_dirs:
if Path.cwd() not in directory.parents:
reload_dirs.append(directory)
if Path.cwd() not in reload_dirs:
reload_dirs.append(Path.cwd())
for w in reload_dirs:
self.watchers.append(CustomWatcher(w.resolve(), self.config))
def should_restart(self) -> list[Path] | None:
self.pause()
for watcher in self.watchers:
change = watcher.check()
if change != set():
return list({Path(c[1]) for c in change})
return None

View File

@@ -1,14 +1,23 @@
from __future__ import annotations
import asyncio
import logging
import signal
import sys
from typing import Any, Dict
import warnings
from typing import Any
from gunicorn.arbiter import Arbiter
from gunicorn.workers.base import Worker
from uvicorn.config import Config
from uvicorn.main import Server
from uvicorn.server import Server
warnings.warn(
"The `uvicorn.workers` module is deprecated. Please use `uvicorn-worker` package instead.\n"
"For more details, see https://github.com/Kludex/uvicorn-worker.",
DeprecationWarning,
)
class UvicornWorker(Worker):
@@ -17,10 +26,10 @@ class UvicornWorker(Worker):
rather than a WSGI callable.
"""
CONFIG_KWARGS: Dict[str, Any] = {"loop": "auto", "http": "auto"}
CONFIG_KWARGS: dict[str, Any] = {"loop": "auto", "http": "auto"}
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(UvicornWorker, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
logger = logging.getLogger("uvicorn.error")
logger.handlers = self.log.error_log.handlers
@@ -63,7 +72,7 @@ class UvicornWorker(Worker):
def init_process(self) -> None:
self.config.setup_event_loop()
super(UvicornWorker, self).init_process()
super().init_process()
def init_signals(self) -> None:
# Reset signals so Gunicorn doesn't swallow subprocess return codes