1
0
mirror of https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git synced 2025-08-14 00:25:46 +02:00

Проверка 09.02.2025

This commit is contained in:
MoonTestUse1
2025-02-09 01:11:49 +06:00
parent ce52f8a23a
commit 0aa3ef8fc2
5827 changed files with 14316 additions and 1906434 deletions

View File

@@ -1,5 +1,5 @@
# util/__init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -65,6 +65,7 @@ from .compat import osx as osx
from .compat import py310 as py310
from .compat import py311 as py311
from .compat import py312 as py312
from .compat import py313 as py313
from .compat import py38 as py38
from .compat import py39 as py39
from .compat import pypy as pypy

View File

@@ -1,5 +1,5 @@
# util/_collections.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -16,6 +16,7 @@ import typing
from typing import Any
from typing import Callable
from typing import cast
from typing import Container
from typing import Dict
from typing import FrozenSet
from typing import Generic
@@ -79,8 +80,8 @@ def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
Example::
>>> a = ['__tablename__', 'id', 'x', 'created_at']
>>> b = ['id', 'name', 'data', 'y', 'created_at']
>>> a = ["__tablename__", "id", "x", "created_at"]
>>> b = ["id", "name", "data", "y", "created_at"]
>>> merge_lists_w_ordering(a, b)
['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']
@@ -425,15 +426,14 @@ def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
return list(x)
def has_intersection(set_, iterable):
def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
r"""return True if any items of set\_ are present in iterable.
Goes through special effort to ensure __hash__ is not called
on items in iterable that don't support it.
"""
# TODO: optimize, write in C, etc.
return bool(set_.intersection([i for i in iterable if i.__hash__]))
return any(i in set_ for i in iterable if i.__hash__)
def to_set(x):

View File

@@ -1,5 +1,5 @@
# util/_concurrency_py3k.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -19,10 +19,14 @@ from typing import Coroutine
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .langhelpers import memoized_property
from .. import exc
from ..util import py311
from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
from ..util.typing import TypeGuard
_T = TypeVar("_T")
@@ -70,9 +74,10 @@ def is_exit_exception(e: BaseException) -> bool:
class _AsyncIoGreenlet(greenlet):
dead: bool
__sqlalchemy_greenlet_provider__ = True
def __init__(self, fn: Callable[..., Any], driver: greenlet):
greenlet.__init__(self, fn, driver)
self.driver = driver
if _has_gr_context:
self.gr_context = driver.gr_context
@@ -98,7 +103,7 @@ def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
def in_greenlet() -> bool:
current = getcurrent()
return isinstance(current, _AsyncIoGreenlet)
return getattr(current, "__sqlalchemy_greenlet_provider__", False)
def await_only(awaitable: Awaitable[_T]) -> _T:
@@ -112,7 +117,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T:
"""
# this is called in the context greenlet while running fn
current = getcurrent()
if not isinstance(current, _AsyncIoGreenlet):
if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
_safe_cancel_awaitable(awaitable)
raise exc.MissingGreenlet(
@@ -124,7 +129,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T:
# a coroutine to run. Once the awaitable is done, the driver greenlet
# switches back to this greenlet with the result of awaitable that is
# then returned to the caller (or raised as error)
return current.driver.switch(awaitable) # type: ignore[no-any-return]
return current.parent.switch(awaitable) # type: ignore[no-any-return,attr-defined] # noqa: E501
def await_fallback(awaitable: Awaitable[_T]) -> _T:
@@ -144,7 +149,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T:
# this is called in the context greenlet while running fn
current = getcurrent()
if not isinstance(current, _AsyncIoGreenlet):
if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
loop = get_event_loop()
if loop.is_running():
_safe_cancel_awaitable(awaitable)
@@ -156,7 +161,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T:
)
return loop.run_until_complete(awaitable)
return current.driver.switch(awaitable) # type: ignore[no-any-return]
return current.parent.switch(awaitable) # type: ignore[no-any-return,attr-defined] # noqa: E501
async def greenlet_spawn(
@@ -182,24 +187,21 @@ async def greenlet_spawn(
# coroutine to wait. If the context is dead the function has
# returned, and its result can be returned.
switch_occurred = False
try:
result = context.switch(*args, **kwargs)
while not context.dead:
switch_occurred = True
try:
# wait for a coroutine from await_only and then return its
# result back to it.
value = await result
except BaseException:
# this allows an exception to be raised within
# the moderated greenlet so that it can continue
# its expected flow.
result = context.throw(*sys.exc_info())
else:
result = context.switch(value)
finally:
# clean up to avoid cycle resolution by gc
del context.driver
result = context.switch(*args, **kwargs)
while not context.dead:
switch_occurred = True
try:
# wait for a coroutine from await_only and then return its
# result back to it.
value = await result
except BaseException:
# this allows an exception to be raised within
# the moderated greenlet so that it can continue
# its expected flow.
result = context.throw(*sys.exc_info())
else:
result = context.switch(value)
if _require_await and not switch_occurred:
raise exc.AwaitRequired(
"The current operation required an async execution but none was "
@@ -225,34 +227,6 @@ class AsyncAdaptedLock:
self.mutex.release()
def _util_async_run_coroutine_function(
fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
) -> Any:
"""for test suite/ util only"""
loop = get_event_loop()
if loop.is_running():
raise Exception(
"for async run coroutine we expect that no greenlet or event "
"loop is running when we start out"
)
return loop.run_until_complete(fn(*args, **kwargs))
def _util_async_run(
fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
) -> Any:
"""for test suite/ util only"""
loop = get_event_loop()
if not loop.is_running():
return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs))
else:
# allow for a wrapped test function to call another
assert isinstance(getcurrent(), _AsyncIoGreenlet)
return fn(*args, **kwargs)
def get_event_loop() -> asyncio.AbstractEventLoop:
"""vendor asyncio.get_event_loop() for python 3.7 and above.
@@ -265,3 +239,50 @@ def get_event_loop() -> asyncio.AbstractEventLoop:
# avoid "During handling of the above exception, another exception..."
pass
return asyncio.get_event_loop_policy().get_event_loop()
if not TYPE_CHECKING and py311:
_Runner = asyncio.Runner
else:
class _Runner:
"""Runner implementation for test only"""
_loop: Union[None, asyncio.AbstractEventLoop, Literal[False]]
def __init__(self) -> None:
self._loop = None
def __enter__(self) -> Self:
self._lazy_init()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def close(self) -> None:
if self._loop:
try:
self._loop.run_until_complete(
self._loop.shutdown_asyncgens()
)
finally:
self._loop.close()
self._loop = False
def get_loop(self) -> asyncio.AbstractEventLoop:
"""Return embedded event loop."""
self._lazy_init()
assert self._loop
return self._loop
def run(self, coro: Coroutine[Any, Any, _T]) -> _T:
self._lazy_init()
assert self._loop
return self._loop.run_until_complete(coro)
def _lazy_init(self) -> None:
if self._loop is False:
raise RuntimeError("Runner is closed")
if self._loop is None:
self._loop = asyncio.new_event_loop()

View File

@@ -1,5 +1,5 @@
# util/_has_cy.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# util/_py_collections.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# util/compat.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -32,6 +32,7 @@ from typing import Type
from typing import TypeVar
py313 = sys.version_info >= (3, 13)
py312 = sys.version_info >= (3, 12)
py311 = sys.version_info >= (3, 11)
py310 = sys.version_info >= (3, 10)
@@ -58,7 +59,7 @@ class FullArgSpec(typing.NamedTuple):
varkw: Optional[str]
defaults: Optional[Tuple[Any, ...]]
kwonlyargs: List[str]
kwonlydefaults: Dict[str, Any]
kwonlydefaults: Optional[Dict[str, Any]]
annotations: Dict[str, Any]

View File

@@ -1,5 +1,5 @@
# util/concurrency.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -10,11 +10,15 @@ from __future__ import annotations
import asyncio # noqa
import typing
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import TypeVar
have_greenlet = False
greenlet_error = None
try:
import greenlet # type: ignore # noqa: F401
import greenlet # type: ignore[import-untyped,unused-ignore] # noqa: F401,E501
except ImportError as e:
greenlet_error = str(e)
pass
@@ -26,12 +30,43 @@ else:
from ._concurrency_py3k import greenlet_spawn as greenlet_spawn
from ._concurrency_py3k import is_exit_exception as is_exit_exception
from ._concurrency_py3k import AsyncAdaptedLock as AsyncAdaptedLock
from ._concurrency_py3k import (
_util_async_run as _util_async_run,
) # noqa: F401
from ._concurrency_py3k import (
_util_async_run_coroutine_function as _util_async_run_coroutine_function, # noqa: F401, E501
)
from ._concurrency_py3k import _Runner
_T = TypeVar("_T")
class _AsyncUtil:
"""Asyncio util for test suite/ util only"""
def __init__(self) -> None:
if have_greenlet:
self.runner = _Runner()
def run(
self,
fn: Callable[..., Coroutine[Any, Any, _T]],
*args: Any,
**kwargs: Any,
) -> _T:
"""Run coroutine on the loop"""
return self.runner.run(fn(*args, **kwargs))
def run_in_greenlet(
self, fn: Callable[..., _T], *args: Any, **kwargs: Any
) -> _T:
"""Run sync function in greenlet. Support nested calls"""
if have_greenlet:
if self.runner.get_loop().is_running():
return fn(*args, **kwargs)
else:
return self.runner.run(greenlet_spawn(fn, *args, **kwargs))
else:
return fn(*args, **kwargs)
def close(self) -> None:
if have_greenlet:
self.runner.close()
if not typing.TYPE_CHECKING and not have_greenlet:

View File

@@ -1,5 +1,5 @@
# util/deprecations.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -205,10 +205,10 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]:
weak_identity_map=(
"0.7",
"the :paramref:`.Session.weak_identity_map parameter "
"is deprecated."
"is deprecated.",
)
)
def some_function(**kwargs): ...
"""

View File

@@ -1,5 +1,5 @@
# util/langhelpers.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -174,10 +174,11 @@ def string_or_unprintable(element: Any) -> str:
return "unprintable element %r" % element
def clsname_as_plain_name(cls: Type[Any]) -> str:
return " ".join(
n.lower() for n in re.findall(r"([A-Z][a-z]+|SQL)", cls.__name__)
)
def clsname_as_plain_name(
cls: Type[Any], use_name: Optional[str] = None
) -> str:
name = use_name or cls.__name__
return " ".join(n.lower() for n in re.findall(r"([A-Z][a-z]+|SQL)", name))
def method_is_overridden(
@@ -307,10 +308,10 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
)
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
decorated.__wrapped__ = fn # type: ignore
return cast(_Fn, update_wrapper(decorated, fn))
decorated.__wrapped__ = fn # type: ignore[attr-defined]
return update_wrapper(decorated, fn) # type: ignore[return-value]
return update_wrapper(decorate, target)
return update_wrapper(decorate, target) # type: ignore[return-value]
def _update_argspec_defaults_into_env(spec, env):
@@ -661,7 +662,9 @@ def format_argspec_init(method, grouped=True):
"""format_argspec_plus with considerations for typical __init__ methods
Wraps format_argspec_plus with error handling strategies for typical
__init__ cases::
__init__ cases:
.. sourcecode:: text
object.__init__ -> (self)
other unreflectable (usually C) -> (self, *args, **kwargs)
@@ -716,7 +719,9 @@ def create_proxy_methods(
def getargspec_init(method):
"""inspect.getargspec with considerations for typical __init__ methods
Wraps inspect.getargspec with error handling for typical __init__ cases::
Wraps inspect.getargspec with error handling for typical __init__ cases:
.. sourcecode:: text
object.__init__ -> (self)
other unreflectable (usually C) -> (self, *args, **kwargs)
@@ -1590,9 +1595,9 @@ class hybridmethod(Generic[_T]):
class symbol(int):
"""A constant symbol.
>>> symbol('foo') is symbol('foo')
>>> symbol("foo") is symbol("foo")
True
>>> symbol('foo')
>>> symbol("foo")
<symbol 'foo>
A slight refinement of the MAGICCOOKIE=object() pattern. The primary
@@ -1658,6 +1663,8 @@ class _IntFlagMeta(type):
items: List[symbol]
cls._items = items = []
for k, v in dict_.items():
if re.match(r"^__.*__$", k):
continue
if isinstance(v, int):
sym = symbol(k, canonical=v)
elif not k.startswith("_"):
@@ -1957,6 +1964,9 @@ def attrsetter(attrname):
return env["set"]
_dunders = re.compile("^__.+__$")
class TypingOnly:
"""A mixin class that marks a class as 'typing only', meaning it has
absolutely no methods, attributes, or runtime functionality whatsoever.
@@ -1967,15 +1977,9 @@ class TypingOnly:
def __init_subclass__(cls) -> None:
if TypingOnly in cls.__bases__:
remaining = set(cls.__dict__).difference(
{
"__module__",
"__doc__",
"__slots__",
"__orig_bases__",
"__annotations__",
}
)
remaining = {
name for name in cls.__dict__ if not _dunders.match(name)
}
if remaining:
raise AssertionError(
f"Class {cls} directly inherits TypingOnly but has "
@@ -2208,3 +2212,11 @@ def has_compiled_ext(raise_=False):
)
else:
return False
class _Missing(enum.Enum):
Missing = enum.auto()
Missing = _Missing.Missing
MissingOr = Union[_T, Literal[_Missing.Missing]]

View File

@@ -1,5 +1,5 @@
# util/preloaded.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# util/queue.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# util/tool_support.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# util/topological.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# util/typing.py
# Copyright (C) 2022-2024 the SQLAlchemy authors and contributors
# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -9,6 +9,7 @@
from __future__ import annotations
import builtins
from collections import deque
import collections.abc as collections_abc
import re
import sys
@@ -54,6 +55,7 @@ if True: # zimports removes the tailing comments
from typing_extensions import TypeGuard as TypeGuard # 3.10
from typing_extensions import Self as Self # 3.11
from typing_extensions import TypeAliasType as TypeAliasType # 3.12
from typing_extensions import Never as Never # 3.11
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
@@ -62,6 +64,13 @@ _KT_contra = TypeVar("_KT_contra", contravariant=True)
_VT = TypeVar("_VT")
_VT_co = TypeVar("_VT_co", covariant=True)
if compat.py38:
# typing_extensions.Literal is different from typing.Literal until
# Python 3.10.1
LITERAL_TYPES = frozenset([typing.Literal, Literal])
else:
LITERAL_TYPES = frozenset([Literal])
if compat.py310:
# why they took until py310 to put this in stdlib is beyond me,
@@ -72,16 +81,13 @@ else:
NoneFwd = ForwardRef("None")
typing_get_args = get_args
typing_get_origin = get_origin
_AnnotationScanType = Union[
Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
]
class ArgsTypeProcotol(Protocol):
class ArgsTypeProtocol(Protocol):
"""protocol for types that have ``__args__``
there's no public interface for this AFAIK
@@ -188,9 +194,51 @@ def de_stringify_annotation(
)
return _copy_generic_annotation_with(annotation, elements)
return annotation # type: ignore
def fixup_container_fwd_refs(
type_: _AnnotationScanType,
) -> _AnnotationScanType:
"""Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')]
and similar for list, set
"""
if (
is_generic(type_)
and get_origin(type_)
in (
dict,
set,
list,
collections_abc.MutableSet,
collections_abc.MutableMapping,
collections_abc.MutableSequence,
collections_abc.Mapping,
collections_abc.Sequence,
)
# fight, kick and scream to struggle to tell the difference between
# dict[] and typing.Dict[] which DO NOT compare the same and DO NOT
# behave the same yet there is NO WAY to distinguish between which type
# it is using public attributes
and not re.match(
"typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_)
)
):
# compat with py3.10 and earlier
return get_origin(type_).__class_getitem__( # type: ignore
tuple(
[
ForwardRef(elem) if isinstance(elem, str) else elem
for elem in get_args(type_)
]
)
)
return type_
def _copy_generic_annotation_with(
annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
) -> Type[_T]:
@@ -281,30 +329,8 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
return getattr(obj, "__name__", name)
def de_stringify_union_elements(
cls: Type[Any],
annotation: ArgsTypeProcotol,
originating_module: str,
locals_: Mapping[str, Any],
*,
str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
) -> Type[Any]:
return make_union_type(
*[
de_stringify_annotation(
cls,
anno,
originating_module,
{},
str_cleanup_fn=str_cleanup_fn,
)
for anno in annotation.__args__
]
)
def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
return type_ is not None and typing_get_origin(type_) is Annotated
def is_pep593(type_: Optional[Any]) -> bool:
return type_ is not None and get_origin(type_) is Annotated
def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
@@ -313,8 +339,8 @@ def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
)
def is_literal(type_: _AnnotationScanType) -> bool:
return get_origin(type_) is Literal
def is_literal(type_: Any) -> bool:
return get_origin(type_) in LITERAL_TYPES
def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
@@ -337,16 +363,62 @@ def flatten_newtype(type_: NewType) -> Type[Any]:
super_type = type_.__supertype__
while is_newtype(super_type):
super_type = super_type.__supertype__
return super_type
return super_type # type: ignore[return-value]
def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
"""Extracts the value from a TypeAliasType, recursively exploring unions
and inner TypeAliasType to flatten them into a single set.
Forward references are not evaluated, so no recursive exploration happens
into them.
"""
_seen = set()
def recursive_value(type_):
if type_ in _seen:
# recursion are not supported (at least it's flagged as
# an error by pyright). Just avoid infinite loop
return type_
_seen.add(type_)
if not is_pep695(type_):
return type_
value = type_.__value__
if not is_union(value):
return value
return [recursive_value(t) for t in value.__args__]
res = recursive_value(type_)
if isinstance(res, list):
types = set()
stack = deque(res)
while stack:
t = stack.popleft()
if isinstance(t, list):
stack.extend(t)
else:
types.add(None if t in {NoneType, NoneFwd} else t)
return types
else:
return {res}
def is_fwd_ref(
type_: _AnnotationScanType, check_generic: bool = False
type_: _AnnotationScanType,
check_generic: bool = False,
check_for_plain_string: bool = False,
) -> TypeGuard[ForwardRef]:
if isinstance(type_, ForwardRef):
if check_for_plain_string and isinstance(type_, str):
return True
elif isinstance(type_, ForwardRef):
return True
elif check_generic and is_generic(type_):
return any(is_fwd_ref(arg, True) for arg in type_.__args__)
return any(
is_fwd_ref(
arg, True, check_for_plain_string=check_for_plain_string
)
for arg in type_.__args__
)
else:
return False
@@ -371,13 +443,31 @@ def de_optionalize_union_types(
"""Given a type, filter out ``Union`` types that include ``NoneType``
to not include the ``NoneType``.
Contains extra logic to work on non-flattened unions, unions that contain
``None`` (seen in py38, 37)
"""
if is_fwd_ref(type_):
return de_optionalize_fwd_ref_union_types(type_)
return _de_optionalize_fwd_ref_union_types(type_, False)
elif is_optional(type_):
typ = set(type_.__args__)
elif is_union(type_) and includes_none(type_):
if compat.py39:
typ = set(type_.__args__)
else:
# py38, 37 - unions are not automatically flattened, can contain
# None rather than NoneType
stack_of_unions = deque([type_])
typ = set()
while stack_of_unions:
u_typ = stack_of_unions.popleft()
for elem in u_typ.__args__:
if is_union(elem):
stack_of_unions.append(elem)
else:
typ.add(elem)
typ.discard(None) # type: ignore
typ.discard(NoneType)
typ.discard(NoneFwd)
@@ -388,9 +478,21 @@ def de_optionalize_union_types(
return type_
def de_optionalize_fwd_ref_union_types(
type_: ForwardRef,
) -> _AnnotationScanType:
@overload
def _de_optionalize_fwd_ref_union_types(
type_: ForwardRef, return_has_none: Literal[True]
) -> bool: ...
@overload
def _de_optionalize_fwd_ref_union_types(
type_: ForwardRef, return_has_none: Literal[False]
) -> _AnnotationScanType: ...
def _de_optionalize_fwd_ref_union_types(
type_: ForwardRef, return_has_none: bool
) -> Union[_AnnotationScanType, bool]:
"""return the non-optional type for Optional[], Union[None, ...], x|None,
etc. without de-stringifying forward refs.
@@ -402,68 +504,78 @@ def de_optionalize_fwd_ref_union_types(
mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
if mm:
if mm.group(1) == "Optional":
return ForwardRef(mm.group(2))
elif mm.group(1) == "Union":
elements = re.split(r",\s*", mm.group(2))
return make_union_type(
*[ForwardRef(elem) for elem in elements if elem != "None"]
)
g1 = mm.group(1).split(".")[-1]
if g1 == "Optional":
return True if return_has_none else ForwardRef(mm.group(2))
elif g1 == "Union":
if "[" in mm.group(2):
# cases like "Union[Dict[str, int], int, None]"
elements: list[str] = []
current: list[str] = []
ignore_comma = 0
for char in mm.group(2):
if char == "[":
ignore_comma += 1
elif char == "]":
ignore_comma -= 1
elif ignore_comma == 0 and char == ",":
elements.append("".join(current).strip())
current.clear()
continue
current.append(char)
else:
elements = re.split(r",\s*", mm.group(2))
parts = [ForwardRef(elem) for elem in elements if elem != "None"]
if return_has_none:
return len(elements) != len(parts)
else:
return make_union_type(*parts) if parts else Never # type: ignore[return-value] # noqa: E501
else:
return type_
return False if return_has_none else type_
pipe_tokens = re.split(r"\s*\|\s*", annotation)
if "None" in pipe_tokens:
return ForwardRef("|".join(p for p in pipe_tokens if p != "None"))
has_none = "None" in pipe_tokens
if return_has_none:
return has_none
if has_none:
anno_str = "|".join(p for p in pipe_tokens if p != "None")
return ForwardRef(anno_str) if anno_str else Never # type: ignore[return-value] # noqa: E501
return type_
def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
"""Make a Union type.
"""Make a Union type."""
return Union.__getitem__(types) # type: ignore
This is needed by :func:`.de_optionalize_union_types` which removes
``NoneType`` from a ``Union``.
def includes_none(type_: Any) -> bool:
"""Returns if the type annotation ``type_`` allows ``None``.
This function supports:
* forward refs
* unions
* pep593 - Annotated
* pep695 - TypeAliasType (does not support looking into
fw reference of other pep695)
* NewType
* plain types like ``int``, ``None``, etc
"""
return cast(Any, Union).__getitem__(types) # type: ignore
def expand_unions(
type_: Type[Any], include_union: bool = False, discard_none: bool = False
) -> Tuple[Type[Any], ...]:
"""Return a type as a tuple of individual types, expanding for
``Union`` types."""
if is_fwd_ref(type_):
return _de_optionalize_fwd_ref_union_types(type_, True)
if is_union(type_):
typ = set(type_.__args__)
if discard_none:
typ.discard(NoneType)
if include_union:
return (type_,) + tuple(typ) # type: ignore
else:
return tuple(typ) # type: ignore
else:
return (type_,)
return any(includes_none(t) for t in get_args(type_))
if is_pep593(type_):
return includes_none(get_args(type_)[0])
if is_pep695(type_):
return any(includes_none(t) for t in pep695_values(type_))
if is_newtype(type_):
return includes_none(type_.__supertype__)
return type_ in (NoneFwd, NoneType, None)
def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
return is_origin_of(
type_,
"Optional",
"Union",
"UnionType",
)
def is_optional_union(type_: Any) -> bool:
return is_optional(type_) and NoneType in typing_get_args(type_)
def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
return is_origin_of(type_, "Union")
def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]:
return is_origin_of(type_, "Union", "UnionType")
def is_origin_of_cls(
@@ -472,7 +584,7 @@ def is_origin_of_cls(
"""return True if the given type has an __origin__ that shares a base
with the given class"""
origin = typing_get_origin(type_)
origin = get_origin(type_)
if origin is None:
return False
@@ -485,7 +597,7 @@ def is_origin_of(
"""return True if the given type has an __origin__ with the given name
and optional module."""
origin = typing_get_origin(type_)
origin = get_origin(type_)
if origin is None:
return False
@@ -575,6 +687,3 @@ class CallableReference(Generic[_FN]):
def __set__(self, instance: Any, value: _FN) -> None: ...
def __delete__(self, instance: Any) -> None: ...
# $def ro_descriptor_reference(fn: Callable[])