mirror of
https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git
synced 2025-08-14 00:25:46 +02:00
Initial commit
This commit is contained in:
176
venv/Lib/site-packages/starlette/middleware/cors.py
Normal file
176
venv/Lib/site-packages/starlette/middleware/cors.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import functools
|
||||
import re
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
|
||||
SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
|
||||
|
||||
|
||||
class CORSMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allow_origins: typing.Sequence[str] = (),
|
||||
allow_methods: typing.Sequence[str] = ("GET",),
|
||||
allow_headers: typing.Sequence[str] = (),
|
||||
allow_credentials: bool = False,
|
||||
allow_origin_regex: typing.Optional[str] = None,
|
||||
expose_headers: typing.Sequence[str] = (),
|
||||
max_age: int = 600,
|
||||
) -> None:
|
||||
if "*" in allow_methods:
|
||||
allow_methods = ALL_METHODS
|
||||
|
||||
compiled_allow_origin_regex = None
|
||||
if allow_origin_regex is not None:
|
||||
compiled_allow_origin_regex = re.compile(allow_origin_regex)
|
||||
|
||||
allow_all_origins = "*" in allow_origins
|
||||
allow_all_headers = "*" in allow_headers
|
||||
preflight_explicit_allow_origin = not allow_all_origins or allow_credentials
|
||||
|
||||
simple_headers = {}
|
||||
if allow_all_origins:
|
||||
simple_headers["Access-Control-Allow-Origin"] = "*"
|
||||
if allow_credentials:
|
||||
simple_headers["Access-Control-Allow-Credentials"] = "true"
|
||||
if expose_headers:
|
||||
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
|
||||
|
||||
preflight_headers = {}
|
||||
if preflight_explicit_allow_origin:
|
||||
# The origin value will be set in preflight_response() if it is allowed.
|
||||
preflight_headers["Vary"] = "Origin"
|
||||
else:
|
||||
preflight_headers["Access-Control-Allow-Origin"] = "*"
|
||||
preflight_headers.update(
|
||||
{
|
||||
"Access-Control-Allow-Methods": ", ".join(allow_methods),
|
||||
"Access-Control-Max-Age": str(max_age),
|
||||
}
|
||||
)
|
||||
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
|
||||
if allow_headers and not allow_all_headers:
|
||||
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
|
||||
if allow_credentials:
|
||||
preflight_headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
self.app = app
|
||||
self.allow_origins = allow_origins
|
||||
self.allow_methods = allow_methods
|
||||
self.allow_headers = [h.lower() for h in allow_headers]
|
||||
self.allow_all_origins = allow_all_origins
|
||||
self.allow_all_headers = allow_all_headers
|
||||
self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
|
||||
self.allow_origin_regex = compiled_allow_origin_regex
|
||||
self.simple_headers = simple_headers
|
||||
self.preflight_headers = preflight_headers
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http": # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
method = scope["method"]
|
||||
headers = Headers(scope=scope)
|
||||
origin = headers.get("origin")
|
||||
|
||||
if origin is None:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
if method == "OPTIONS" and "access-control-request-method" in headers:
|
||||
response = self.preflight_response(request_headers=headers)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.simple_response(scope, receive, send, request_headers=headers)
|
||||
|
||||
def is_allowed_origin(self, origin: str) -> bool:
|
||||
if self.allow_all_origins:
|
||||
return True
|
||||
|
||||
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
|
||||
origin
|
||||
):
|
||||
return True
|
||||
|
||||
return origin in self.allow_origins
|
||||
|
||||
def preflight_response(self, request_headers: Headers) -> Response:
|
||||
requested_origin = request_headers["origin"]
|
||||
requested_method = request_headers["access-control-request-method"]
|
||||
requested_headers = request_headers.get("access-control-request-headers")
|
||||
|
||||
headers = dict(self.preflight_headers)
|
||||
failures = []
|
||||
|
||||
if self.is_allowed_origin(origin=requested_origin):
|
||||
if self.preflight_explicit_allow_origin:
|
||||
# The "else" case is already accounted for in self.preflight_headers
|
||||
# and the value would be "*".
|
||||
headers["Access-Control-Allow-Origin"] = requested_origin
|
||||
else:
|
||||
failures.append("origin")
|
||||
|
||||
if requested_method not in self.allow_methods:
|
||||
failures.append("method")
|
||||
|
||||
# If we allow all headers, then we have to mirror back any requested
|
||||
# headers in the response.
|
||||
if self.allow_all_headers and requested_headers is not None:
|
||||
headers["Access-Control-Allow-Headers"] = requested_headers
|
||||
elif requested_headers is not None:
|
||||
for header in [h.lower() for h in requested_headers.split(",")]:
|
||||
if header.strip() not in self.allow_headers:
|
||||
failures.append("headers")
|
||||
break
|
||||
|
||||
# We don't strictly need to use 400 responses here, since its up to
|
||||
# the browser to enforce the CORS policy, but its more informative
|
||||
# if we do.
|
||||
if failures:
|
||||
failure_text = "Disallowed CORS " + ", ".join(failures)
|
||||
return PlainTextResponse(failure_text, status_code=400, headers=headers)
|
||||
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(
|
||||
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def send(
|
||||
self, message: Message, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
||||
message.setdefault("headers", [])
|
||||
headers = MutableHeaders(scope=message)
|
||||
headers.update(self.simple_headers)
|
||||
origin = request_headers["Origin"]
|
||||
has_cookie = "cookie" in request_headers
|
||||
|
||||
# If request includes any cookie headers, then we must respond
|
||||
# with the specific origin instead of '*'.
|
||||
if self.allow_all_origins and has_cookie:
|
||||
self.allow_explicit_origin(headers, origin)
|
||||
|
||||
# If we only allow specific origins, then we have to mirror back
|
||||
# the Origin header in the response.
|
||||
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
|
||||
self.allow_explicit_origin(headers, origin)
|
||||
|
||||
await send(message)
|
||||
|
||||
@staticmethod
|
||||
def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
headers.add_vary_header("Origin")
|
Reference in New Issue
Block a user