1
0
mirror of https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git synced 2025-08-14 00:25:46 +02:00

Проверка 09.02.2025

This commit is contained in:
MoonTestUse1
2025-02-09 01:11:49 +06:00
parent ce52f8a23a
commit 0aa3ef8fc2
5827 changed files with 14316 additions and 1906434 deletions

View File

@@ -1,12 +1,6 @@
import asyncio
from uvicorn._types import (
ASGIReceiveCallable,
ASGISendCallable,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
Scope,
)
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
CLOSE_HEADER = (b"connection", b"close")
@@ -22,7 +16,7 @@ class FlowControl:
self._is_writable_event.set()
async def drain(self) -> None:
await self._is_writable_event.wait()
await self._is_writable_event.wait() # pragma: full coverage
def pause_reading(self) -> None:
if not self.read_paused:
@@ -35,32 +29,26 @@ class FlowControl:
self._transport.resume_reading()
def pause_writing(self) -> None:
if not self.write_paused:
if not self.write_paused: # pragma: full coverage
self.write_paused = True
self._is_writable_event.clear()
def resume_writing(self) -> None:
if self.write_paused:
if self.write_paused: # pragma: full coverage
self.write_paused = False
self._is_writable_event.set()
async def service_unavailable(
scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
) -> None:
response_start: "HTTPResponseStartEvent" = {
"type": "http.response.start",
"status": 503,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
await send(response_start)
response_body: "HTTPResponseBodyEvent" = {
"type": "http.response.body",
"body": b"Service Unavailable",
"more_body": False,
}
await send(response_body)
async def service_unavailable(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
await send(
{
"type": "http.response.start",
"status": 503,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"content-length", b"19"),
(b"connection", b"close"),
],
}
)
await send({"type": "http.response.body", "body": b"Service Unavailable", "more_body": False})

View File

@@ -20,19 +20,8 @@ from uvicorn._types import (
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
FlowControl,
service_unavailable,
)
from uvicorn.protocols.utils import (
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.server import ServerState
@@ -43,9 +32,7 @@ def _get_status_phrase(status_code: int) -> bytes:
return b""
STATUS_PHRASES = {
status_code: _get_status_phrase(status_code) for status_code in range(100, 600)
}
STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
class H11Protocol(asyncio.Protocol):
@@ -160,14 +147,24 @@ class H11Protocol(asyncio.Protocol):
def _should_upgrade_to_ws(self) -> bool:
if self.ws_protocol_class is None:
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False
return True
def _unsupported_upgrade_warning(self) -> None:
msg = "Unsupported upgrade request."
self.logger.warning(msg)
if not self._should_upgrade_to_ws():
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
def _should_upgrade(self) -> bool:
upgrade = self._get_upgrade()
if upgrade == b"websocket" and self._should_upgrade_to_ws():
return True
if upgrade is not None:
self._unsupported_upgrade_warning()
return False
def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()
@@ -203,10 +200,7 @@ class H11Protocol(asyncio.Protocol):
full_raw_path = self.root_path.encode("ascii") + raw_path
self.scope = {
"type": "http",
"asgi": {
"version": self.config.asgi_version,
"spec_version": "2.3",
},
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"http_version": event.http_version.decode("ascii"),
"server": self.server,
"client": self.client,
@@ -219,16 +213,13 @@ class H11Protocol(asyncio.Protocol):
"headers": self.headers,
"state": self.app_state.copy(),
}
upgrade = self._get_upgrade()
if upgrade == b"websocket" and self._should_upgrade_to_ws():
if self._should_upgrade():
self.handle_websocket_upgrade(event)
return
# Handle 503 responses when 'limit_concurrency' is exceeded.
if self.limit_concurrency is not None and (
len(self.connections) >= self.limit_concurrency
or len(self.tasks) >= self.limit_concurrency
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
):
app = service_unavailable
message = "Exceeded concurrency limit."
@@ -275,9 +266,11 @@ class H11Protocol(asyncio.Protocol):
continue
self.cycle.more_body = False
self.cycle.message_event.set()
if self.conn.their_state == h11.MUST_CLOSE:
break
def handle_websocket_upgrade(self, event: h11.Request) -> None:
if self.logger.level <= TRACE_LOG_LEVEL:
if self.logger.level <= TRACE_LOG_LEVEL: # pragma: full coverage
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
@@ -322,9 +315,7 @@ class H11Protocol(asyncio.Protocol):
# Set a short Keep-Alive timeout.
self._unset_keepalive_if_required()
self.timeout_keep_alive_task = self.loop.call_later(
self.timeout_keep_alive, self.timeout_keep_alive_handler
)
self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler)
# Unpause data reads if needed.
self.flow.resume_reading()
@@ -349,13 +340,13 @@ class H11Protocol(asyncio.Protocol):
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.flow.pause_writing()
self.flow.pause_writing() # pragma: full coverage
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.flow.resume_writing()
self.flow.resume_writing() # pragma: full coverage
def timeout_keep_alive_handler(self) -> None:
"""
@@ -371,7 +362,7 @@ class H11Protocol(asyncio.Protocol):
class RequestResponseCycle:
def __init__(
self,
scope: "HTTPScope",
scope: HTTPScope,
conn: h11.Connection,
transport: asyncio.Transport,
flow: FlowControl,
@@ -407,7 +398,7 @@ class RequestResponseCycle:
self.response_complete = False
# ASGI exception wrapper
async def run_asgi(self, app: "ASGI3Application") -> None:
async def run_asgi(self, app: ASGI3Application) -> None:
try:
result = await app( # type: ignore[func-returns-value]
self.scope, self.receive, self.send
@@ -436,7 +427,7 @@ class RequestResponseCycle:
self.on_response = lambda: None
async def send_500_response(self) -> None:
response_start_event: "HTTPResponseStartEvent" = {
response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": 500,
"headers": [
@@ -445,7 +436,7 @@ class RequestResponseCycle:
],
}
await self.send(response_start_event)
response_body_event: "HTTPResponseBodyEvent" = {
response_body_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
@@ -453,14 +444,14 @@ class RequestResponseCycle:
await self.send(response_body_event)
# ASGI interface
async def send(self, message: "ASGISendEvent") -> None:
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if self.flow.write_paused and not self.disconnected:
await self.flow.drain()
await self.flow.drain() # pragma: full coverage
if self.disconnected:
return
return # pragma: full coverage
if not self.response_started:
# Sending response status line and headers
@@ -527,12 +518,10 @@ class RequestResponseCycle:
self.transport.close()
self.on_response()
async def receive(self) -> "ASGIReceiveEvent":
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
headers: list[tuple[str, str]] = []
event = h11.InformationalResponse(
status_code=100, headers=headers, reason="Continue"
)
event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue")
output = self.conn.send(event=event)
self.transport.write(output)
self.waiting_for_100_continue = False
@@ -545,7 +534,7 @@ class RequestResponseCycle:
if self.disconnected or self.response_complete:
return {"type": "http.disconnect"}
message: "HTTPRequestEvent" = {
message: HTTPRequestEvent = {
"type": "http.request",
"body": self.body,
"more_body": self.more_body,

View File

@@ -7,7 +7,7 @@ import re
import urllib
from asyncio.events import TimerHandle
from collections import deque
from typing import Any, Callable, Deque, Literal, cast
from typing import Any, Callable, Literal, cast
import httptools
@@ -15,31 +15,18 @@ from uvicorn._types import (
ASGI3Application,
ASGIReceiveEvent,
ASGISendEvent,
HTTPDisconnectEvent,
HTTPRequestEvent,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
HTTPScope,
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
FlowControl,
service_unavailable,
)
from uvicorn.protocols.utils import (
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.server import ServerState
HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")
HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x08\x0a-\x1f\x7f]")
def _get_status_line(status_code: int) -> bytes:
@@ -50,9 +37,7 @@ def _get_status_line(status_code: int) -> bytes:
return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"])
STATUS_LINE = {
status_code: _get_status_line(status_code) for status_code in range(100, 600)
}
STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
class HttpToolsProtocol(asyncio.Protocol):
@@ -73,6 +58,14 @@ class HttpToolsProtocol(asyncio.Protocol):
self.access_logger = logging.getLogger("uvicorn.access")
self.access_log = self.access_logger.hasHandlers()
self.parser = httptools.HttpRequestParser(self)
try:
# Enable dangerous leniencies to allow server to a response on the first request from a pipelined request.
self.parser.set_dangerous_leniencies(lenient_data_after_close=True)
except AttributeError: # pragma: no cover
# httptools < 0.6.3
pass
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
@@ -93,7 +86,7 @@ class HttpToolsProtocol(asyncio.Protocol):
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["http", "https"] | None = None
self.pipeline: Deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
# Per-request state
self.scope: HTTPScope = None # type: ignore[assignment]
@@ -154,21 +147,22 @@ class HttpToolsProtocol(asyncio.Protocol):
upgrade = value.lower()
if b"upgrade" in connection:
return upgrade
return None
return None # pragma: full coverage
def _should_upgrade_to_ws(self, upgrade: bytes | None) -> bool:
if upgrade == b"websocket" and self.ws_protocol_class is not None:
return True
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
def _should_upgrade_to_ws(self) -> bool:
if self.ws_protocol_class is None:
return False
return True
def _unsupported_upgrade_warning(self) -> None:
self.logger.warning("Unsupported upgrade request.")
if not self._should_upgrade_to_ws():
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False
def _should_upgrade(self) -> bool:
upgrade = self._get_upgrade()
return self._should_upgrade_to_ws(upgrade)
return upgrade == b"websocket" and self._should_upgrade_to_ws()
def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()
@@ -181,9 +175,10 @@ class HttpToolsProtocol(asyncio.Protocol):
self.send_400_response(msg)
return
except httptools.HttpParserUpgrade:
upgrade = self._get_upgrade()
if self._should_upgrade_to_ws(upgrade):
if self._should_upgrade():
self.handle_websocket_upgrade()
else:
self._unsupported_upgrade_warning()
def handle_websocket_upgrade(self) -> None:
if self.logger.level <= TRACE_LOG_LEVEL:
@@ -208,7 +203,7 @@ class HttpToolsProtocol(asyncio.Protocol):
def send_400_response(self, msg: str) -> None:
content = [STATUS_LINE[400]]
for name, value in self.server_state.default_headers:
content.extend([name, b": ", value, b"\r\n"])
content.extend([name, b": ", value, b"\r\n"]) # pragma: full coverage
content.extend(
[
b"content-type: text/plain; charset=utf-8\r\n",
@@ -268,8 +263,7 @@ class HttpToolsProtocol(asyncio.Protocol):
# Handle 503 responses when 'limit_concurrency' is exceeded.
if self.limit_concurrency is not None and (
len(self.connections) >= self.limit_concurrency
or len(self.tasks) >= self.limit_concurrency
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
):
app = service_unavailable
message = "Exceeded concurrency limit."
@@ -302,9 +296,7 @@ class HttpToolsProtocol(asyncio.Protocol):
self.pipeline.appendleft((self.cycle, app))
def on_body(self, body: bytes) -> None:
if (
self.parser.should_upgrade() and self._should_upgrade()
) or self.cycle.response_complete:
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
return
self.cycle.body += body
if len(self.cycle.body) > HIGH_WATER_LIMIT:
@@ -312,9 +304,7 @@ class HttpToolsProtocol(asyncio.Protocol):
self.cycle.message_event.set()
def on_message_complete(self) -> None:
if (
self.parser.should_upgrade() and self._should_upgrade()
) or self.cycle.response_complete:
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
return
self.cycle.more_body = False
self.cycle.message_event.set()
@@ -356,13 +346,13 @@ class HttpToolsProtocol(asyncio.Protocol):
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.flow.pause_writing()
self.flow.pause_writing() # pragma: full coverage
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.flow.resume_writing()
self.flow.resume_writing() # pragma: full coverage
def timeout_keep_alive_handler(self) -> None:
"""
@@ -376,7 +366,7 @@ class HttpToolsProtocol(asyncio.Protocol):
class RequestResponseCycle:
def __init__(
self,
scope: "HTTPScope",
scope: HTTPScope,
transport: asyncio.Transport,
flow: FlowControl,
logger: logging.Logger,
@@ -414,7 +404,7 @@ class RequestResponseCycle:
self.expected_content_length = 0
# ASGI exception wrapper
async def run_asgi(self, app: "ASGI3Application") -> None:
async def run_asgi(self, app: ASGI3Application) -> None:
try:
result = await app( # type: ignore[func-returns-value]
self.scope, self.receive, self.send
@@ -443,31 +433,28 @@ class RequestResponseCycle:
self.on_response = lambda: None
async def send_500_response(self) -> None:
response_start_event: "HTTPResponseStartEvent" = {
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
await self.send(response_start_event)
response_body_event: "HTTPResponseBodyEvent" = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
}
await self.send(response_body_event)
await self.send(
{
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"content-length", b"21"),
(b"connection", b"close"),
],
}
)
await self.send({"type": "http.response.body", "body": b"Internal Server Error", "more_body": False})
# ASGI interface
async def send(self, message: "ASGISendEvent") -> None:
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if self.flow.write_paused and not self.disconnected:
await self.flow.drain()
await self.flow.drain() # pragma: full coverage
if self.disconnected:
return
return # pragma: full coverage
if not self.response_started:
# Sending response status line and headers
@@ -500,7 +487,7 @@ class RequestResponseCycle:
for name, value in headers:
if HEADER_RE.search(name):
raise RuntimeError("Invalid HTTP header name.")
raise RuntimeError("Invalid HTTP header name.") # pragma: full coverage
if HEADER_VALUE_RE.search(value):
raise RuntimeError("Invalid HTTP header value.")
@@ -515,11 +502,7 @@ class RequestResponseCycle:
self.keep_alive = False
content.extend([name, b": ", value, b"\r\n"])
if (
self.chunked_encoding is None
and self.scope["method"] != "HEAD"
and status_code not in (204, 304)
):
if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304):
# Neither content-length nor transfer-encoding specified
self.chunked_encoding = True
content.append(b"transfer-encoding: chunked\r\n")
@@ -570,7 +553,7 @@ class RequestResponseCycle:
msg = "Unexpected ASGI message '%s' sent, after response already completed."
raise RuntimeError(msg % message_type)
async def receive(self) -> "ASGIReceiveEvent":
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
self.waiting_for_100_continue = False
@@ -580,15 +563,8 @@ class RequestResponseCycle:
await self.message_event.wait()
self.message_event.clear()
message: HTTPDisconnectEvent | HTTPRequestEvent
if self.disconnected or self.response_complete:
message = {"type": "http.disconnect"}
else:
message = {
"type": "http.request",
"body": self.body,
"more_body": self.more_body,
}
self.body = b""
return {"type": "http.disconnect"}
message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
self.body = b""
return message

View File

@@ -6,8 +6,7 @@ import urllib.parse
from uvicorn._types import WWWScope
class ClientDisconnected(IOError):
...
class ClientDisconnected(OSError): ...
def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
@@ -53,7 +52,5 @@ def get_client_addr(scope: WWWScope) -> str:
def get_path_with_query_string(scope: WWWScope) -> str:
path_with_query_string = urllib.parse.quote(scope["path"])
if scope["query_string"]:
path_with_query_string = "{}?{}".format(
path_with_query_string, scope["query_string"].decode("ascii")
)
path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii"))
return path_with_query_string

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import asyncio
import typing
AutoWebSocketsProtocol: typing.Optional[typing.Callable[..., asyncio.Protocol]]
AutoWebSocketsProtocol: typing.Callable[..., asyncio.Protocol] | None
try:
import websockets # noqa
except ImportError: # pragma: no cover

View File

@@ -3,27 +3,22 @@ from __future__ import annotations
import asyncio
import http
import logging
from typing import (
Any,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from collections.abc import Sequence
from typing import Any, Literal, Optional, cast
from urllib.parse import unquote
import websockets
import websockets.legacy.handshake
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.extensions.base import ServerExtensionFactory
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.legacy.server import HTTPResponse
from websockets.server import WebSocketServerProtocol
from websockets.typing import Subprotocol
from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
@@ -61,7 +56,8 @@ class Server:
class WebSocketProtocol(WebSocketServerProtocol):
extra_headers: List[Tuple[str, str]]
extra_headers: list[tuple[str, str]]
logger: logging.Logger | logging.LoggerAdapter[Any]
def __init__(
self,
@@ -74,7 +70,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
config.load()
self.config = config
self.app = config.loaded_app
self.app = cast(ASGI3Application, config.loaded_app)
self.loop = _loop or asyncio.get_event_loop()
self.root_path = config.root_path
self.app_state = app_state
@@ -101,7 +97,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.ws_server: Server = Server() # type: ignore[assignment]
extensions = []
extensions: list[ServerExtensionFactory] = []
if self.config.ws_per_message_deflate:
extensions.append(ServerPerMessageDeflateFactory())
@@ -117,8 +113,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
)
self.server_header = None
self.extra_headers = [
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in server_state.default_headers
(name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers
]
def connection_made( # type: ignore[override]
@@ -136,16 +131,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
super().connection_made(transport)
def connection_lost(self, exc: Optional[Exception]) -> None:
def connection_lost(self, exc: Exception | None) -> None:
self.connections.remove(self)
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
self.lost_connection_before_handshake = (
not self.handshake_completed_event.is_set()
)
self.lost_connection_before_handshake = not self.handshake_completed_event.is_set()
self.handshake_completed_event.set()
super().connection_lost(exc)
if exc is None:
@@ -159,12 +152,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.send_500_response()
self.transport.close()
def on_task_complete(self, task: asyncio.Task) -> None:
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)
async def process_request(
self, path: str, headers: Headers
) -> Optional[HTTPResponse]:
async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None:
"""
This hook is called to determine if the websocket should return
an HTTP response and close.
@@ -175,15 +166,15 @@ class WebSocketProtocol(WebSocketServerProtocol):
"""
path_portion, _, query_string = path.partition("?")
websockets.legacy.handshake.check_request(headers)
websockets.legacy.handshake.check_request(request_headers)
subprotocols = []
for header in headers.get_all("Sec-WebSocket-Protocol"):
subprotocols: list[str] = []
for header in request_headers.get_all("Sec-WebSocket-Protocol"):
subprotocols.extend([token.strip() for token in header.split(",")])
asgi_headers = [
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
for name, value in headers.raw_items()
for name, value in request_headers.raw_items()
]
path = unquote(path_portion)
full_path = self.root_path + path
@@ -212,8 +203,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
return self.initial_response
def process_subprotocol(
self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
) -> Optional[Subprotocol]:
self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
) -> Subprotocol | None:
"""
We override the standard 'process_subprotocol' behavior here so that
we return whatever subprotocol is sent in the 'accept' message.
@@ -223,8 +214,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
def send_500_response(self) -> None:
msg = b"Internal Server Error"
content = [
b"HTTP/1.1 500 Internal Server Error\r\n"
b"content-type: text/plain; charset=utf-8\r\n",
b"HTTP/1.1 500 Internal Server Error\r\n" b"content-type: text/plain; charset=utf-8\r\n",
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
b"connection: close\r\n",
b"\r\n",
@@ -235,9 +225,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
# itself (see https://github.com/encode/uvicorn/issues/920)
self.handshake_started_event.set()
async def ws_handler( # type: ignore[override]
self, protocol: WebSocketServerProtocol, path: str
) -> Any:
async def ws_handler(self, protocol: WebSocketServerProtocol, path: str) -> Any: # type: ignore[override]
"""
This is the main handler function for the 'websockets' implementation
to call into. We just wait for close then return, and instead allow
@@ -252,14 +240,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
termination states.
"""
try:
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
except ClientDisconnected:
result = await self.app(self.scope, self.asgi_receive, self.asgi_send) # type: ignore[func-returns-value]
except ClientDisconnected: # pragma: full coverage
self.closed_event.set()
self.transport.close()
except BaseException as exc:
except BaseException:
self.closed_event.set()
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
self.logger.exception("Exception in ASGI application\n")
if not self.handshake_started_event.is_set():
self.send_500_response()
else:
@@ -268,17 +255,15 @@ class WebSocketProtocol(WebSocketServerProtocol):
else:
self.closed_event.set()
if not self.handshake_started_event.is_set():
msg = "ASGI callable returned without sending handshake."
self.logger.error(msg)
self.logger.error("ASGI callable returned without sending handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
await self.handshake_completed_event.wait()
self.transport.close()
async def asgi_send(self, message: "ASGISendEvent") -> None:
async def asgi_send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if not self.handshake_started_event.is_set():
@@ -290,9 +275,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
get_path_with_query_string(self.scope),
)
self.initial_response = None
self.accepted_subprotocol = cast(
Optional[Subprotocol], message.get("subprotocol")
)
self.accepted_subprotocol = cast(Optional[Subprotocol], message.get("subprotocol"))
if "headers" in message:
self.extra_headers.extend(
# ASGI spec requires bytes
@@ -324,8 +307,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
# websockets requires the status to be an enum. look it up.
status = http.HTTPStatus(message["status"])
headers = [
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in message.get("headers", [])
(name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", [])
]
self.initial_response = (status, headers, b"")
self.handshake_started_event.set()
@@ -356,10 +338,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.closed_event.set()
else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
raise RuntimeError(msg % message_type)
except ConnectionClosed as exc:
raise ClientDisconnected from exc
@@ -372,24 +351,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
if not message.get("more_body", False):
self.closed_event.set()
else:
msg = (
"Expected ASGI message 'websocket.http.response.body' "
"but got '%s'."
)
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
raise RuntimeError(msg % message_type)
else:
msg = (
"Unexpected ASGI message '%s', after sending 'websocket.close' "
"or response already completed."
)
msg = "Unexpected ASGI message '%s', after sending 'websocket.close' " "or response already completed."
raise RuntimeError(msg % message_type)
async def asgi_receive(
self,
) -> Union[
"WebSocketDisconnectEvent", "WebSocketConnectEvent", "WebSocketReceiveEvent"
]:
async def asgi_receive(self) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
if not self.connect_sent:
self.connect_sent = True
return {"type": "websocket.connect"}
@@ -406,11 +375,11 @@ class WebSocketProtocol(WebSocketServerProtocol):
try:
data = await self.recv()
except ConnectionClosed as exc:
except ConnectionClosed:
self.closed_event.set()
if self.ws_server.closing:
return {"type": "websocket.disconnect", "code": 1012}
return {"type": "websocket.disconnect", "code": exc.code}
return {"type": "websocket.disconnect", "code": self.close_code or 1005, "reason": self.close_reason}
if isinstance(data, str):
return {"type": "websocket.receive", "text": data}

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
import logging
import typing
from typing import Literal
from typing import Literal, cast
from urllib.parse import unquote
import wsproto
@@ -13,6 +13,7 @@ from wsproto.extensions import Extension, PerMessageDeflate
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
@@ -43,10 +44,10 @@ class WSProtocol(asyncio.Protocol):
_loop: asyncio.AbstractEventLoop | None = None,
) -> None:
if not config.loaded:
config.load()
config.load() # pragma: full coverage
self.config = config
self.app = config.loaded_app
self.app = cast(ASGI3Application, config.loaded_app)
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
@@ -139,13 +140,13 @@ class WSProtocol(asyncio.Protocol):
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.writable.clear()
self.writable.clear() # pragma: full coverage
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.writable.set()
self.writable.set() # pragma: full coverage
def shutdown(self) -> None:
if self.handshake_complete:
@@ -156,7 +157,7 @@ class WSProtocol(asyncio.Protocol):
self.send_500_response()
self.transport.close()
def on_task_complete(self, task: asyncio.Task) -> None:
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)
# Event handlers
@@ -168,7 +169,7 @@ class WSProtocol(asyncio.Protocol):
path = unquote(raw_path)
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii")
self.scope: "WebSocketScope" = {
self.scope: WebSocketScope = {
"type": "websocket",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
"http_version": "1.1",
@@ -211,7 +212,7 @@ class WSProtocol(asyncio.Protocol):
def handle_close(self, event: events.CloseConnection) -> None:
if self.conn.state == ConnectionState.REMOTE_CLOSING:
self.transport.write(self.conn.send(event.response()))
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code})
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code, "reason": event.reason})
self.transport.close()
def handle_ping(self, event: events.Ping) -> None:
@@ -220,38 +221,31 @@ class WSProtocol(asyncio.Protocol):
def send_500_response(self) -> None:
if self.response_started or self.handshake_complete:
return # we cannot send responses anymore
headers = [
headers: list[tuple[bytes, bytes]] = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
(b"content-length", b"21"),
]
output = self.conn.send(
wsproto.events.RejectConnection(
status_code=500, headers=headers, has_body=True
)
)
output += self.conn.send(
wsproto.events.RejectData(data=b"Internal Server Error")
)
output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True))
output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error"))
self.transport.write(output)
async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
result = await self.app(self.scope, self.receive, self.send) # type: ignore[func-returns-value]
except ClientDisconnected:
self.transport.close()
self.transport.close() # pragma: full coverage
except BaseException:
self.logger.exception("Exception in ASGI application\n")
self.send_500_response()
self.transport.close()
else:
if not self.handshake_complete:
msg = "ASGI callable returned without completing handshake."
self.logger.error(msg)
self.logger.error("ASGI callable returned without completing handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
self.transport.close()
async def send(self, message: ASGISendEvent) -> None:
@@ -269,7 +263,7 @@ class WSProtocol(asyncio.Protocol):
)
subprotocol = message.get("subprotocol")
extra_headers = self.default_headers + list(message.get("headers", []))
extensions: typing.List[Extension] = []
extensions: list[Extension] = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
if not self.transport.is_closing():
@@ -343,21 +337,14 @@ class WSProtocol(asyncio.Protocol):
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait(
{"type": "websocket.disconnect", "code": code}
)
output = self.conn.send(
wsproto.events.CloseConnection(code=code, reason=reason)
)
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason))
if not self.transport.is_closing():
self.transport.write(output)
self.transport.close()
else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
raise RuntimeError(msg % message_type)
except LocalProtocolError as exc:
raise ClientDisconnected from exc
@@ -365,24 +352,17 @@ class WSProtocol(asyncio.Protocol):
if message_type == "websocket.http.response.body":
message = typing.cast("WebSocketResponseBodyEvent", message)
body_finished = not message.get("more_body", False)
reject_data = events.RejectData(
data=message["body"], body_finished=body_finished
)
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
output = self.conn.send(reject_data)
self.transport.write(output)
if body_finished:
self.queue.put_nowait(
{"type": "websocket.disconnect", "code": 1006}
)
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.close_sent = True
self.transport.close()
else:
msg = (
"Expected ASGI message 'websocket.http.response.body' "
"but got '%s'."
)
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
raise RuntimeError(msg % message_type)
else: