1
0
mirror of https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git synced 2025-08-14 00:25:46 +02:00

Все подряд

This commit is contained in:
MoonTestUse1
2024-12-31 02:37:57 +06:00
parent 8e53bb6cb2
commit d5780b2eab
3258 changed files with 1087440 additions and 268 deletions

View File

@@ -0,0 +1,29 @@
from sqlalchemy.testing import config
from sqlalchemy.testing import emits_warning
from sqlalchemy.testing import engines
from sqlalchemy.testing import exclusions
from sqlalchemy.testing import mock
from sqlalchemy.testing import provide_metadata
from sqlalchemy.testing import skip_if
from sqlalchemy.testing import uses_deprecated
from sqlalchemy.testing.config import combinations
from sqlalchemy.testing.config import fixture
from sqlalchemy.testing.config import requirements as requires
from .assertions import assert_raises
from .assertions import assert_raises_message
from .assertions import emits_python_deprecation_warning
from .assertions import eq_
from .assertions import eq_ignore_whitespace
from .assertions import expect_raises
from .assertions import expect_raises_message
from .assertions import expect_sqlalchemy_deprecated
from .assertions import expect_sqlalchemy_deprecated_20
from .assertions import expect_warnings
from .assertions import is_
from .assertions import is_false
from .assertions import is_not_
from .assertions import is_true
from .assertions import ne_
from .fixtures import TestBase
from .util import resolve_lambda

View File

@@ -0,0 +1,167 @@
from __future__ import annotations
import contextlib
import re
import sys
from typing import Any
from typing import Dict
from sqlalchemy import exc as sa_exc
from sqlalchemy.engine import default
from sqlalchemy.testing.assertions import _expect_warnings
from sqlalchemy.testing.assertions import eq_ # noqa
from sqlalchemy.testing.assertions import is_ # noqa
from sqlalchemy.testing.assertions import is_false # noqa
from sqlalchemy.testing.assertions import is_not_ # noqa
from sqlalchemy.testing.assertions import is_true # noqa
from sqlalchemy.testing.assertions import ne_ # noqa
from sqlalchemy.util import decorator
from ..util import sqla_compat
def _assert_proper_exception_context(exception):
"""assert that any exception we're catching does not have a __context__
without a __cause__, and that __suppress_context__ is never set.
Python 3 will report nested as exceptions as "during the handling of
error X, error Y occurred". That's not what we want to do. we want
these exceptions in a cause chain.
"""
if (
exception.__context__ is not exception.__cause__
and not exception.__suppress_context__
):
assert False, (
"Exception %r was correctly raised but did not set a cause, "
"within context %r as its cause."
% (exception, exception.__context__)
)
def assert_raises(except_cls, callable_, *args, **kw):
return _assert_raises(except_cls, callable_, args, kw, check_context=True)
def assert_raises_context_ok(except_cls, callable_, *args, **kw):
return _assert_raises(except_cls, callable_, args, kw)
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
return _assert_raises(
except_cls, callable_, args, kwargs, msg=msg, check_context=True
)
def assert_raises_message_context_ok(
except_cls, msg, callable_, *args, **kwargs
):
return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
def _assert_raises(
except_cls, callable_, args, kwargs, msg=None, check_context=False
):
with _expect_raises(except_cls, msg, check_context) as ec:
callable_(*args, **kwargs)
return ec.error
class _ErrorContainer:
error: Any = None
@contextlib.contextmanager
def _expect_raises(except_cls, msg=None, check_context=False):
ec = _ErrorContainer()
if check_context:
are_we_already_in_a_traceback = sys.exc_info()[0]
try:
yield ec
success = False
except except_cls as err:
ec.error = err
success = True
if msg is not None:
assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}"
if check_context and not are_we_already_in_a_traceback:
_assert_proper_exception_context(err)
print(str(err).encode("utf-8"))
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
def expect_raises(except_cls, check_context=True):
return _expect_raises(except_cls, check_context=check_context)
def expect_raises_message(except_cls, msg, check_context=True):
return _expect_raises(except_cls, msg=msg, check_context=check_context)
def eq_ignore_whitespace(a, b, msg=None):
a = re.sub(r"^\s+?|\n", "", a)
a = re.sub(r" {2,}", " ", a)
b = re.sub(r"^\s+?|\n", "", b)
b = re.sub(r" {2,}", " ", b)
assert a == b, msg or "%r != %r" % (a, b)
_dialect_mods: Dict[Any, Any] = {}
def _get_dialect(name):
if name is None or name == "default":
return default.DefaultDialect()
else:
d = sqla_compat._create_url(name).get_dialect()()
if name == "postgresql":
d.implicit_returning = True
elif name == "mssql":
d.legacy_schema_aliasing = False
return d
def expect_warnings(*messages, **kw):
"""Context manager which expects one or more warnings.
With no arguments, squelches all SAWarnings emitted via
sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
pass string expressions that will match selected warnings via regex;
all non-matching warnings are sent through.
The expect version **asserts** that the warnings were in fact seen.
Note that the test suite sets SAWarning warnings to raise exceptions.
"""
return _expect_warnings(Warning, messages, **kw)
def emits_python_deprecation_warning(*messages):
"""Decorator form of expect_warnings().
Note that emits_warning does **not** assert that the warnings
were in fact seen.
"""
@decorator
def decorate(fn, *args, **kw):
with _expect_warnings(DeprecationWarning, assert_=False, *messages):
return fn(*args, **kw)
return decorate
def expect_sqlalchemy_deprecated(*messages, **kw):
return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
def expect_sqlalchemy_deprecated_20(*messages, **kw):
return _expect_warnings(sa_exc.RemovedIn20Warning, messages, **kw)

View File

@@ -0,0 +1,518 @@
import importlib.machinery
import os
import shutil
import textwrap
from sqlalchemy.testing import config
from sqlalchemy.testing import provision
from . import util as testing_util
from .. import command
from .. import script
from .. import util
from ..script import Script
from ..script import ScriptDirectory
def _get_staging_directory():
if provision.FOLLOWER_IDENT:
return "scratch_%s" % provision.FOLLOWER_IDENT
else:
return "scratch"
def staging_env(create=True, template="generic", sourceless=False):
cfg = _testing_config()
if create:
path = os.path.join(_get_staging_directory(), "scripts")
assert not os.path.exists(path), (
"staging directory %s already exists; poor cleanup?" % path
)
command.init(cfg, path, template=template)
if sourceless:
try:
# do an import so that a .pyc/.pyo is generated.
util.load_python_file(path, "env.py")
except AttributeError:
# we don't have the migration context set up yet
# so running the .env py throws this exception.
# theoretically we could be using py_compiler here to
# generate .pyc/.pyo without importing but not really
# worth it.
pass
assert sourceless in (
"pep3147_envonly",
"simple",
"pep3147_everything",
), sourceless
make_sourceless(
os.path.join(path, "env.py"),
"pep3147" if "pep3147" in sourceless else "simple",
)
sc = script.ScriptDirectory.from_config(cfg)
return sc
def clear_staging_env():
from sqlalchemy.testing import engines
engines.testing_reaper.close_all()
shutil.rmtree(_get_staging_directory(), True)
def script_file_fixture(txt):
dir_ = os.path.join(_get_staging_directory(), "scripts")
path = os.path.join(dir_, "script.py.mako")
with open(path, "w") as f:
f.write(txt)
def env_file_fixture(txt):
dir_ = os.path.join(_get_staging_directory(), "scripts")
txt = (
"""
from alembic import context
config = context.config
"""
+ txt
)
path = os.path.join(dir_, "env.py")
pyc_path = util.pyc_file_from_path(path)
if pyc_path:
os.unlink(pyc_path)
with open(path, "w") as f:
f.write(txt)
def _sqlite_file_db(tempname="foo.db", future=False, scope=None, **options):
dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/%s" % (dir_, tempname)
if scope and util.sqla_14:
options["scope"] = scope
return testing_util.testing_engine(url=url, future=future, options=options)
def _sqlite_testing_config(sourceless=False, future=False):
dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/foo.db" % dir_
sqlalchemy_future = future or ("future" in config.db.__class__.__module__)
return _write_config_file(
"""
[alembic]
script_location = %s
sqlalchemy.url = %s
sourceless = %s
%s
[loggers]
keys = root,sqlalchemy
[handlers]
keys = console
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = DEBUG
handlers =
qualname = sqlalchemy.engine
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatters]
keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (
dir_,
url,
"true" if sourceless else "false",
"sqlalchemy.future = true" if sqlalchemy_future else "",
)
)
def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
dir_ = os.path.join(_get_staging_directory(), "scripts")
sqlalchemy_future = "future" in config.db.__class__.__module__
url = "sqlite:///%s/foo.db" % dir_
return _write_config_file(
"""
[alembic]
script_location = %s
sqlalchemy.url = %s
sqlalchemy.future = %s
sourceless = %s
version_locations = %%(here)s/model1/ %%(here)s/model2/ %%(here)s/model3/ %s
[loggers]
keys = root
[handlers]
keys = console
[logger_root]
level = WARN
handlers = console
qualname =
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatters]
keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (
dir_,
url,
"true" if sqlalchemy_future else "false",
"true" if sourceless else "false",
extra_version_location,
)
)
def _no_sql_testing_config(dialect="postgresql", directives=""):
"""use a postgresql url with no host so that
connections guaranteed to fail"""
dir_ = os.path.join(_get_staging_directory(), "scripts")
return _write_config_file(
"""
[alembic]
script_location = %s
sqlalchemy.url = %s://
%s
[loggers]
keys = root
[handlers]
keys = console
[logger_root]
level = WARN
handlers = console
qualname =
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatters]
keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (dir_, dialect, directives)
)
def _write_config_file(text):
cfg = _testing_config()
with open(cfg.config_file_name, "w") as f:
f.write(text)
return cfg
def _testing_config():
from alembic.config import Config
if not os.access(_get_staging_directory(), os.F_OK):
os.mkdir(_get_staging_directory())
return Config(os.path.join(_get_staging_directory(), "test_alembic.ini"))
def write_script(
scriptdir, rev_id, content, encoding="ascii", sourceless=False
):
old = scriptdir.revision_map.get_revision(rev_id)
path = old.path
content = textwrap.dedent(content)
if encoding:
content = content.encode(encoding)
with open(path, "wb") as fp:
fp.write(content)
pyc_path = util.pyc_file_from_path(path)
if pyc_path:
os.unlink(pyc_path)
script = Script._from_path(scriptdir, path)
old = scriptdir.revision_map.get_revision(script.revision)
if old.down_revision != script.down_revision:
raise Exception(
"Can't change down_revision " "on a refresh operation."
)
scriptdir.revision_map.add_revision(script, _replace=True)
if sourceless:
make_sourceless(
path, "pep3147" if sourceless == "pep3147_everything" else "simple"
)
def make_sourceless(path, style):
import py_compile
py_compile.compile(path)
if style == "simple":
pyc_path = util.pyc_file_from_path(path)
suffix = importlib.machinery.BYTECODE_SUFFIXES[0]
filepath, ext = os.path.splitext(path)
simple_pyc_path = filepath + suffix
shutil.move(pyc_path, simple_pyc_path)
pyc_path = simple_pyc_path
else:
assert style in ("pep3147", "simple")
pyc_path = util.pyc_file_from_path(path)
assert os.access(pyc_path, os.F_OK)
os.unlink(path)
def three_rev_fixture(cfg):
a = util.rev_id()
b = util.rev_id()
c = util.rev_id()
script = ScriptDirectory.from_config(cfg)
script.generate_revision(a, "revision a", refresh=True, head="base")
write_script(
script,
a,
"""\
"Rev A"
revision = '%s'
down_revision = None
from alembic import op
def upgrade():
op.execute("CREATE STEP 1")
def downgrade():
op.execute("DROP STEP 1")
"""
% a,
)
script.generate_revision(b, "revision b", refresh=True, head=a)
write_script(
script,
b,
f"""# coding: utf-8
"Rev B, méil, %3"
revision = '{b}'
down_revision = '{a}'
from alembic import op
def upgrade():
op.execute("CREATE STEP 2")
def downgrade():
op.execute("DROP STEP 2")
""",
encoding="utf-8",
)
script.generate_revision(c, "revision c", refresh=True, head=b)
write_script(
script,
c,
"""\
"Rev C"
revision = '%s'
down_revision = '%s'
from alembic import op
def upgrade():
op.execute("CREATE STEP 3")
def downgrade():
op.execute("DROP STEP 3")
"""
% (c, b),
)
return a, b, c
def multi_heads_fixture(cfg, a, b, c):
"""Create a multiple head fixture from the three-revs fixture"""
# a->b->c
# -> d -> e
# -> f
d = util.rev_id()
e = util.rev_id()
f = util.rev_id()
script = ScriptDirectory.from_config(cfg)
script.generate_revision(
d, "revision d from b", head=b, splice=True, refresh=True
)
write_script(
script,
d,
"""\
"Rev D"
revision = '%s'
down_revision = '%s'
from alembic import op
def upgrade():
op.execute("CREATE STEP 4")
def downgrade():
op.execute("DROP STEP 4")
"""
% (d, b),
)
script.generate_revision(
e, "revision e from d", head=d, splice=True, refresh=True
)
write_script(
script,
e,
"""\
"Rev E"
revision = '%s'
down_revision = '%s'
from alembic import op
def upgrade():
op.execute("CREATE STEP 5")
def downgrade():
op.execute("DROP STEP 5")
"""
% (e, d),
)
script.generate_revision(
f, "revision f from b", head=b, splice=True, refresh=True
)
write_script(
script,
f,
"""\
"Rev F"
revision = '%s'
down_revision = '%s'
from alembic import op
def upgrade():
op.execute("CREATE STEP 6")
def downgrade():
op.execute("DROP STEP 6")
"""
% (f, b),
)
return d, e, f
def _multidb_testing_config(engines):
"""alembic.ini fixture to work exactly with the 'multidb' template"""
dir_ = os.path.join(_get_staging_directory(), "scripts")
sqlalchemy_future = "future" in config.db.__class__.__module__
databases = ", ".join(engines.keys())
engines = "\n\n".join(
"[%s]\n" "sqlalchemy.url = %s" % (key, value.url)
for key, value in engines.items()
)
return _write_config_file(
"""
[alembic]
script_location = %s
sourceless = false
sqlalchemy.future = %s
databases = %s
%s
[loggers]
keys = root
[handlers]
keys = console
[logger_root]
level = WARN
handlers = console
qualname =
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatters]
keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (dir_, "true" if sqlalchemy_future else "false", databases, engines)
)

View File

@@ -0,0 +1,306 @@
from __future__ import annotations
import configparser
from contextlib import contextmanager
import io
import re
from typing import Any
from typing import Dict
from sqlalchemy import Column
from sqlalchemy import inspect
from sqlalchemy import MetaData
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy.testing import config
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
import alembic
from .assertions import _get_dialect
from ..environment import EnvironmentContext
from ..migration import MigrationContext
from ..operations import Operations
from ..util import sqla_compat
from ..util.sqla_compat import create_mock_engine
from ..util.sqla_compat import sqla_14
from ..util.sqla_compat import sqla_2
testing_config = configparser.ConfigParser()
testing_config.read(["test.cfg"])
class TestBase(SQLAlchemyTestBase):
is_sqlalchemy_future = sqla_2
@testing.fixture()
def ops_context(self, migration_context):
with migration_context.begin_transaction(_per_migration=True):
yield Operations(migration_context)
@testing.fixture
def migration_context(self, connection):
return MigrationContext.configure(
connection, opts=dict(transaction_per_migration=True)
)
@testing.fixture
def connection(self):
with config.db.connect() as conn:
yield conn
class TablesTest(TestBase, SQLAlchemyTablesTest):
pass
if sqla_14:
from sqlalchemy.testing.fixtures import FutureEngineMixin
else:
class FutureEngineMixin: # type:ignore[no-redef]
__requires__ = ("sqlalchemy_14",)
FutureEngineMixin.is_sqlalchemy_future = True
def capture_db(dialect="postgresql://"):
buf = []
def dump(sql, *multiparams, **params):
buf.append(str(sql.compile(dialect=engine.dialect)))
engine = create_mock_engine(dialect, dump)
return engine, buf
_engs: Dict[Any, Any] = {}
@contextmanager
def capture_context_buffer(**kw):
if kw.pop("bytes_io", False):
buf = io.BytesIO()
else:
buf = io.StringIO()
kw.update({"dialect_name": "sqlite", "output_buffer": buf})
conf = EnvironmentContext.configure
def configure(*arg, **opt):
opt.update(**kw)
return conf(*arg, **opt)
with mock.patch.object(EnvironmentContext, "configure", configure):
yield buf
@contextmanager
def capture_engine_context_buffer(**kw):
from .env import _sqlite_file_db
from sqlalchemy import event
buf = io.StringIO()
eng = _sqlite_file_db()
conn = eng.connect()
@event.listens_for(conn, "before_cursor_execute")
def bce(conn, cursor, statement, parameters, context, executemany):
buf.write(statement + "\n")
kw.update({"connection": conn})
conf = EnvironmentContext.configure
def configure(*arg, **opt):
opt.update(**kw)
return conf(*arg, **opt)
with mock.patch.object(EnvironmentContext, "configure", configure):
yield buf
def op_fixture(
dialect="default",
as_sql=False,
naming_convention=None,
literal_binds=False,
native_boolean=None,
):
opts = {}
if naming_convention:
opts["target_metadata"] = MetaData(naming_convention=naming_convention)
class buffer_:
def __init__(self):
self.lines = []
def write(self, msg):
msg = msg.strip()
msg = re.sub(r"[\n\t]", "", msg)
if as_sql:
# the impl produces soft tabs,
# so search for blocks of 4 spaces
msg = re.sub(r" ", "", msg)
msg = re.sub(r"\;\n*$", "", msg)
self.lines.append(msg)
def flush(self):
pass
buf = buffer_()
class ctx(MigrationContext):
def get_buf(self):
return buf
def clear_assertions(self):
buf.lines[:] = []
def assert_(self, *sql):
# TODO: make this more flexible about
# whitespace and such
eq_(buf.lines, [re.sub(r"[\n\t]", "", s) for s in sql])
def assert_contains(self, sql):
for stmt in buf.lines:
if re.sub(r"[\n\t]", "", sql) in stmt:
return
else:
assert False, "Could not locate fragment %r in %r" % (
sql,
buf.lines,
)
if as_sql:
opts["as_sql"] = as_sql
if literal_binds:
opts["literal_binds"] = literal_binds
if not sqla_14 and dialect == "mariadb":
ctx_dialect = _get_dialect("mysql")
ctx_dialect.server_version_info = (10, 4, 0, "MariaDB")
else:
ctx_dialect = _get_dialect(dialect)
if native_boolean is not None:
ctx_dialect.supports_native_boolean = native_boolean
# this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
# which breaks assumptions in the alembic test suite
ctx_dialect.non_native_boolean_check_constraint = True
if not as_sql:
def execute(stmt, *multiparam, **param):
if isinstance(stmt, str):
stmt = text(stmt)
assert stmt.supports_execution
sql = str(stmt.compile(dialect=ctx_dialect))
buf.write(sql)
connection = mock.Mock(dialect=ctx_dialect, execute=execute)
else:
opts["output_buffer"] = buf
connection = None
context = ctx(ctx_dialect, connection, opts)
alembic.op._proxy = Operations(context)
return context
class AlterColRoundTripFixture:
# since these tests are about syntax, use more recent SQLAlchemy as some of
# the type / server default compare logic might not work on older
# SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle
__requires__ = ("alter_column",)
def setUp(self):
self.conn = config.db.connect()
self.ctx = MigrationContext.configure(self.conn)
self.op = Operations(self.ctx)
self.metadata = MetaData()
def _compare_type(self, t1, t2):
c1 = Column("q", t1)
c2 = Column("q", t2)
assert not self.ctx.impl.compare_type(
c1, c2
), "Type objects %r and %r didn't compare as equivalent" % (t1, t2)
def _compare_server_default(self, t1, s1, t2, s2):
c1 = Column("q", t1, server_default=s1)
c2 = Column("q", t2, server_default=s2)
assert not self.ctx.impl.compare_server_default(
c1, c2, s2, s1
), "server defaults %r and %r didn't compare as equivalent" % (s1, s2)
def tearDown(self):
sqla_compat._safe_rollback_connection_transaction(self.conn)
with self.conn.begin():
self.metadata.drop_all(self.conn)
self.conn.close()
def _run_alter_col(self, from_, to_, compare=None):
column = Column(
from_.get("name", "colname"),
from_.get("type", String(10)),
nullable=from_.get("nullable", True),
server_default=from_.get("server_default", None),
# comment=from_.get("comment", None)
)
t = Table("x", self.metadata, column)
with sqla_compat._ensure_scope_for_ddl(self.conn):
t.create(self.conn)
insp = inspect(self.conn)
old_col = insp.get_columns("x")[0]
# TODO: conditional comment support
self.op.alter_column(
"x",
column.name,
existing_type=column.type,
existing_server_default=column.server_default
if column.server_default is not None
else False,
existing_nullable=True if column.nullable else False,
# existing_comment=column.comment,
nullable=to_.get("nullable", None),
# modify_comment=False,
server_default=to_.get("server_default", False),
new_column_name=to_.get("name", None),
type_=to_.get("type", None),
)
insp = inspect(self.conn)
new_col = insp.get_columns("x")[0]
if compare is None:
compare = to_
eq_(
new_col["name"],
compare["name"] if "name" in compare else column.name,
)
self._compare_type(
new_col["type"], compare.get("type", old_col["type"])
)
eq_(new_col["nullable"], compare.get("nullable", column.nullable))
self._compare_server_default(
new_col["type"],
new_col.get("default", None),
compare.get("type", old_col["type"]),
compare["server_default"].text
if "server_default" in compare
else column.server_default.arg.text
if column.server_default is not None
else None,
)

View File

@@ -0,0 +1,4 @@
"""
Bootstrapper for test framework plugins.
"""

View File

@@ -0,0 +1,210 @@
from sqlalchemy.testing.requirements import Requirements
from alembic import util
from alembic.util import sqla_compat
from ..testing import exclusions
class SuiteRequirements(Requirements):
@property
def schemas(self):
"""Target database must support external schemas, and have one
named 'test_schema'."""
return exclusions.open()
@property
def autocommit_isolation(self):
"""target database should support 'AUTOCOMMIT' isolation level"""
return exclusions.closed()
@property
def materialized_views(self):
"""needed for sqlalchemy compat"""
return exclusions.closed()
@property
def unique_constraint_reflection(self):
def doesnt_have_check_uq_constraints(config):
from sqlalchemy import inspect
insp = inspect(config.db)
try:
insp.get_unique_constraints("x")
except NotImplementedError:
return True
except TypeError:
return True
except Exception:
pass
return False
return exclusions.skip_if(doesnt_have_check_uq_constraints)
@property
def sequences(self):
"""Target database must support SEQUENCEs."""
return exclusions.only_if(
[lambda config: config.db.dialect.supports_sequences],
"no sequence support",
)
@property
def foreign_key_match(self):
return exclusions.open()
@property
def foreign_key_constraint_reflection(self):
return exclusions.open()
@property
def check_constraints_w_enforcement(self):
"""Target database must support check constraints
and also enforce them."""
return exclusions.open()
@property
def reflects_pk_names(self):
return exclusions.closed()
@property
def reflects_fk_options(self):
return exclusions.closed()
@property
def sqlalchemy_14(self):
return exclusions.skip_if(
lambda config: not util.sqla_14,
"SQLAlchemy 1.4 or greater required",
)
@property
def sqlalchemy_1x(self):
return exclusions.skip_if(
lambda config: util.sqla_2,
"SQLAlchemy 1.x test",
)
@property
def sqlalchemy_2(self):
return exclusions.skip_if(
lambda config: not util.sqla_2,
"SQLAlchemy 2.x test",
)
@property
def asyncio(self):
def go(config):
try:
import greenlet # noqa: F401
except ImportError:
return False
else:
return True
return self.sqlalchemy_14 + exclusions.only_if(go)
@property
def comments(self):
return exclusions.only_if(
lambda config: config.db.dialect.supports_comments
)
@property
def alter_column(self):
return exclusions.open()
@property
def computed_columns(self):
return exclusions.closed()
@property
def computed_columns_api(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_computed)
)
@property
def computed_reflects_normally(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_computed_reflection)
)
@property
def computed_reflects_as_server_default(self):
return exclusions.closed()
@property
def computed_doesnt_reflect_as_server_default(self):
return exclusions.closed()
@property
def autoincrement_on_composite_pk(self):
return exclusions.closed()
@property
def fk_ondelete_is_reflected(self):
return exclusions.closed()
@property
def fk_onupdate_is_reflected(self):
return exclusions.closed()
@property
def fk_onupdate(self):
return exclusions.open()
@property
def fk_ondelete_restrict(self):
return exclusions.open()
@property
def fk_onupdate_restrict(self):
return exclusions.open()
@property
def fk_ondelete_noaction(self):
return exclusions.open()
@property
def fk_initially(self):
return exclusions.closed()
@property
def fk_deferrable(self):
return exclusions.closed()
@property
def fk_deferrable_is_reflected(self):
return exclusions.closed()
@property
def fk_names(self):
return exclusions.open()
@property
def integer_subtype_comparisons(self):
return exclusions.open()
@property
def no_name_normalize(self):
return exclusions.skip_if(
lambda config: config.db.dialect.requires_name_normalize
)
@property
def identity_columns(self):
return exclusions.closed()
@property
def identity_columns_alter(self):
return exclusions.closed()
@property
def identity_columns_api(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_identity)
)

View File

@@ -0,0 +1,169 @@
from itertools import zip_longest
from sqlalchemy import schema
from sqlalchemy.sql.elements import ClauseList
class CompareTable:
def __init__(self, table):
self.table = table
def __eq__(self, other):
if self.table.name != other.name or self.table.schema != other.schema:
return False
for c1, c2 in zip_longest(self.table.c, other.c):
if (c1 is None and c2 is not None) or (
c2 is None and c1 is not None
):
return False
if CompareColumn(c1) != c2:
return False
return True
# TODO: compare constraints, indexes
def __ne__(self, other):
return not self.__eq__(other)
class CompareColumn:
def __init__(self, column):
self.column = column
def __eq__(self, other):
return (
self.column.name == other.name
and self.column.nullable == other.nullable
)
# TODO: datatypes etc
def __ne__(self, other):
return not self.__eq__(other)
class CompareIndex:
def __init__(self, index, name_only=False):
self.index = index
self.name_only = name_only
def __eq__(self, other):
if self.name_only:
return self.index.name == other.name
else:
return (
str(schema.CreateIndex(self.index))
== str(schema.CreateIndex(other))
and self.index.dialect_kwargs == other.dialect_kwargs
)
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
expr = ClauseList(*self.index.expressions)
try:
expr_str = expr.compile().string
except Exception:
expr_str = str(expr)
return f"<CompareIndex {self.index.name}({expr_str})>"
class CompareCheckConstraint:
def __init__(self, constraint):
self.constraint = constraint
def __eq__(self, other):
return (
isinstance(other, schema.CheckConstraint)
and self.constraint.name == other.name
and (str(self.constraint.sqltext) == str(other.sqltext))
and (other.table.name == self.constraint.table.name)
and other.table.schema == self.constraint.table.schema
)
def __ne__(self, other):
return not self.__eq__(other)
class CompareForeignKey:
def __init__(self, constraint):
self.constraint = constraint
def __eq__(self, other):
r1 = (
isinstance(other, schema.ForeignKeyConstraint)
and self.constraint.name == other.name
and (other.table.name == self.constraint.table.name)
and other.table.schema == self.constraint.table.schema
)
if not r1:
return False
for c1, c2 in zip_longest(self.constraint.columns, other.columns):
if (c1 is None and c2 is not None) or (
c2 is None and c1 is not None
):
return False
if CompareColumn(c1) != c2:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
class ComparePrimaryKey:
def __init__(self, constraint):
self.constraint = constraint
def __eq__(self, other):
r1 = (
isinstance(other, schema.PrimaryKeyConstraint)
and self.constraint.name == other.name
and (other.table.name == self.constraint.table.name)
and other.table.schema == self.constraint.table.schema
)
if not r1:
return False
for c1, c2 in zip_longest(self.constraint.columns, other.columns):
if (c1 is None and c2 is not None) or (
c2 is None and c1 is not None
):
return False
if CompareColumn(c1) != c2:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
class CompareUniqueConstraint:
def __init__(self, constraint):
self.constraint = constraint
def __eq__(self, other):
r1 = (
isinstance(other, schema.UniqueConstraint)
and self.constraint.name == other.name
and (other.table.name == self.constraint.table.name)
and other.table.schema == self.constraint.table.schema
)
if not r1:
return False
for c1, c2 in zip_longest(self.constraint.columns, other.columns):
if (c1 is None and c2 is not None) or (
c2 is None and c1 is not None
):
return False
if CompareColumn(c1) != c2:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)

View File

@@ -0,0 +1,7 @@
from .test_autogen_comments import * # noqa
from .test_autogen_computed import * # noqa
from .test_autogen_diffs import * # noqa
from .test_autogen_fks import * # noqa
from .test_autogen_identity import * # noqa
from .test_environment import * # noqa
from .test_op import * # noqa

View File

@@ -0,0 +1,335 @@
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import Set
from sqlalchemy import CHAR
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import event
from sqlalchemy import ForeignKey
from sqlalchemy import Index
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Numeric
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import Text
from sqlalchemy import text
from sqlalchemy import UniqueConstraint
from ... import autogenerate
from ... import util
from ...autogenerate import api
from ...ddl.base import _fk_spec
from ...migration import MigrationContext
from ...operations import ops
from ...testing import config
from ...testing import eq_
from ...testing.env import clear_staging_env
from ...testing.env import staging_env
names_in_this_test: Set[Any] = set()
@event.listens_for(Table, "after_parent_attach")
def new_table(table, parent):
names_in_this_test.add(table.name)
def _default_include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name in names_in_this_test
else:
return True
_default_object_filters: Any = _default_include_object
_default_name_filters: Any = None
class ModelOne:
__requires__ = ("unique_constraint_reflection",)
schema: Any = None
@classmethod
def _get_db_schema(cls):
schema = cls.schema
m = MetaData(schema=schema)
Table(
"user",
m,
Column("id", Integer, primary_key=True),
Column("name", String(50)),
Column("a1", Text),
Column("pw", String(50)),
Index("pw_idx", "pw"),
)
Table(
"address",
m,
Column("id", Integer, primary_key=True),
Column("email_address", String(100), nullable=False),
)
Table(
"order",
m,
Column("order_id", Integer, primary_key=True),
Column(
"amount",
Numeric(8, 2),
nullable=False,
server_default=text("0"),
),
CheckConstraint("amount >= 0", name="ck_order_amount"),
)
Table(
"extra",
m,
Column("x", CHAR),
Column("uid", Integer, ForeignKey("user.id")),
)
return m
@classmethod
def _get_model_schema(cls):
schema = cls.schema
m = MetaData(schema=schema)
Table(
"user",
m,
Column("id", Integer, primary_key=True),
Column("name", String(50), nullable=False),
Column("a1", Text, server_default="x"),
)
Table(
"address",
m,
Column("id", Integer, primary_key=True),
Column("email_address", String(100), nullable=False),
Column("street", String(50)),
UniqueConstraint("email_address", name="uq_email"),
)
Table(
"order",
m,
Column("order_id", Integer, primary_key=True),
Column(
"amount",
Numeric(10, 2),
nullable=True,
server_default=text("0"),
),
Column("user_id", Integer, ForeignKey("user.id")),
CheckConstraint("amount > -1", name="ck_order_amount"),
)
Table(
"item",
m,
Column("id", Integer, primary_key=True),
Column("description", String(100)),
Column("order_id", Integer, ForeignKey("order.order_id")),
CheckConstraint("len(description) > 5"),
)
return m
class _ComparesFKs:
def _assert_fk_diff(
self,
diff,
type_,
source_table,
source_columns,
target_table,
target_columns,
name=None,
conditional_name=None,
source_schema=None,
onupdate=None,
ondelete=None,
initially=None,
deferrable=None,
):
# the public API for ForeignKeyConstraint was not very rich
# in 0.7, 0.8, so here we use the well-known but slightly
# private API to get at its elements
(
fk_source_schema,
fk_source_table,
fk_source_columns,
fk_target_schema,
fk_target_table,
fk_target_columns,
fk_onupdate,
fk_ondelete,
fk_deferrable,
fk_initially,
) = _fk_spec(diff[1])
eq_(diff[0], type_)
eq_(fk_source_table, source_table)
eq_(fk_source_columns, source_columns)
eq_(fk_target_table, target_table)
eq_(fk_source_schema, source_schema)
eq_(fk_onupdate, onupdate)
eq_(fk_ondelete, ondelete)
eq_(fk_initially, initially)
eq_(fk_deferrable, deferrable)
eq_([elem.column.name for elem in diff[1].elements], target_columns)
if conditional_name is not None:
if conditional_name == "servergenerated":
fks = inspect(self.bind).get_foreign_keys(source_table)
server_fk_name = fks[0]["name"]
eq_(diff[1].name, server_fk_name)
else:
eq_(diff[1].name, conditional_name)
else:
eq_(diff[1].name, name)
class AutogenTest(_ComparesFKs):
def _flatten_diffs(self, diffs):
for d in diffs:
if isinstance(d, list):
yield from self._flatten_diffs(d)
else:
yield d
@classmethod
def _get_bind(cls):
return config.db
configure_opts: Dict[Any, Any] = {}
@classmethod
def setup_class(cls):
staging_env()
cls.bind = cls._get_bind()
cls.m1 = cls._get_db_schema()
cls.m1.create_all(cls.bind)
cls.m2 = cls._get_model_schema()
@classmethod
def teardown_class(cls):
cls.m1.drop_all(cls.bind)
clear_staging_env()
def setUp(self):
self.conn = conn = self.bind.connect()
ctx_opts = {
"compare_type": True,
"compare_server_default": True,
"target_metadata": self.m2,
"upgrade_token": "upgrades",
"downgrade_token": "downgrades",
"alembic_module_prefix": "op.",
"sqlalchemy_module_prefix": "sa.",
"include_object": _default_object_filters,
"include_name": _default_name_filters,
}
if self.configure_opts:
ctx_opts.update(self.configure_opts)
self.context = context = MigrationContext.configure(
connection=conn, opts=ctx_opts
)
self.autogen_context = api.AutogenContext(context, self.m2)
def tearDown(self):
self.conn.close()
def _update_context(
self, object_filters=None, name_filters=None, include_schemas=None
):
if include_schemas is not None:
self.autogen_context.opts["include_schemas"] = include_schemas
if object_filters is not None:
self.autogen_context._object_filters = [object_filters]
if name_filters is not None:
self.autogen_context._name_filters = [name_filters]
return self.autogen_context
class AutogenFixtureTest(_ComparesFKs):
def _fixture(
self,
m1,
m2,
include_schemas=False,
opts=None,
object_filters=_default_object_filters,
name_filters=_default_name_filters,
return_ops=False,
max_identifier_length=None,
):
if max_identifier_length:
dialect = self.bind.dialect
existing_length = dialect.max_identifier_length
dialect.max_identifier_length = (
dialect._user_defined_max_identifier_length
) = max_identifier_length
try:
self._alembic_metadata, model_metadata = m1, m2
for m in util.to_list(self._alembic_metadata):
m.create_all(self.bind)
with self.bind.connect() as conn:
ctx_opts = {
"compare_type": True,
"compare_server_default": True,
"target_metadata": model_metadata,
"upgrade_token": "upgrades",
"downgrade_token": "downgrades",
"alembic_module_prefix": "op.",
"sqlalchemy_module_prefix": "sa.",
"include_object": object_filters,
"include_name": name_filters,
"include_schemas": include_schemas,
}
if opts:
ctx_opts.update(opts)
self.context = context = MigrationContext.configure(
connection=conn, opts=ctx_opts
)
autogen_context = api.AutogenContext(context, model_metadata)
uo = ops.UpgradeOps(ops=[])
autogenerate._produce_net_changes(autogen_context, uo)
if return_ops:
return uo
else:
return uo.as_diffs()
finally:
if max_identifier_length:
dialect = self.bind.dialect
dialect.max_identifier_length = (
dialect._user_defined_max_identifier_length
) = existing_length
def setUp(self):
staging_env()
self.bind = config.db
def tearDown(self):
if hasattr(self, "_alembic_metadata"):
for m in util.to_list(self._alembic_metadata):
m.drop_all(self.bind)
clear_staging_env()

View File

@@ -0,0 +1,242 @@
from sqlalchemy import Column
from sqlalchemy import Float
from sqlalchemy import MetaData
from sqlalchemy import String
from sqlalchemy import Table
from ._autogen_fixtures import AutogenFixtureTest
from ...testing import eq_
from ...testing import mock
from ...testing import TestBase
class AutogenerateCommentsTest(AutogenFixtureTest, TestBase):
__backend__ = True
__requires__ = ("comments",)
def test_existing_table_comment_no_change(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
comment="this is some table",
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
comment="this is some table",
)
diffs = self._fixture(m1, m2)
eq_(diffs, [])
def test_add_table_comment(self):
m1 = MetaData()
m2 = MetaData()
Table("some_table", m1, Column("test", String(10), primary_key=True))
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
comment="this is some table",
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_table_comment")
eq_(diffs[0][1].comment, "this is some table")
eq_(diffs[0][2], None)
def test_remove_table_comment(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
comment="this is some table",
)
Table("some_table", m2, Column("test", String(10), primary_key=True))
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "remove_table_comment")
eq_(diffs[0][1].comment, None)
def test_alter_table_comment(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
comment="this is some table",
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
comment="this is also some table",
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_table_comment")
eq_(diffs[0][1].comment, "this is also some table")
eq_(diffs[0][2], "this is some table")
def test_existing_column_comment_no_change(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the amount"),
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the amount"),
)
diffs = self._fixture(m1, m2)
eq_(diffs, [])
def test_add_column_comment(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
Column("amount", Float),
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the amount"),
)
diffs = self._fixture(m1, m2)
eq_(
diffs,
[
[
(
"modify_comment",
None,
"some_table",
"amount",
{
"existing_nullable": True,
"existing_type": mock.ANY,
"existing_server_default": False,
},
None,
"the amount",
)
]
],
)
def test_remove_column_comment(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the amount"),
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
Column("amount", Float),
)
diffs = self._fixture(m1, m2)
eq_(
diffs,
[
[
(
"modify_comment",
None,
"some_table",
"amount",
{
"existing_nullable": True,
"existing_type": mock.ANY,
"existing_server_default": False,
},
"the amount",
None,
)
]
],
)
def test_alter_column_comment(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the amount"),
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the adjusted amount"),
)
diffs = self._fixture(m1, m2)
eq_(
diffs,
[
[
(
"modify_comment",
None,
"some_table",
"amount",
{
"existing_nullable": True,
"existing_type": mock.ANY,
"existing_server_default": False,
},
"the amount",
"the adjusted amount",
)
]
],
)

View File

@@ -0,0 +1,203 @@
import sqlalchemy as sa
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from ._autogen_fixtures import AutogenFixtureTest
from ... import testing
from ...testing import config
from ...testing import eq_
from ...testing import exclusions
from ...testing import is_
from ...testing import is_true
from ...testing import mock
from ...testing import TestBase
class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
__requires__ = ("computed_columns",)
__backend__ = True
def test_add_computed_column(self):
m1 = MetaData()
m2 = MetaData()
Table("user", m1, Column("id", Integer, primary_key=True))
Table(
"user",
m2,
Column("id", Integer, primary_key=True),
Column("foo", Integer, sa.Computed("5")),
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_column")
eq_(diffs[0][2], "user")
eq_(diffs[0][3].name, "foo")
c = diffs[0][3].computed
is_true(isinstance(c, sa.Computed))
is_(c.persisted, None)
eq_(str(c.sqltext), "5")
def test_remove_computed_column(self):
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer, primary_key=True),
Column("foo", Integer, sa.Computed("5")),
)
Table("user", m2, Column("id", Integer, primary_key=True))
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "remove_column")
eq_(diffs[0][2], "user")
c = diffs[0][3]
eq_(c.name, "foo")
if config.requirements.computed_reflects_normally.enabled:
is_true(isinstance(c.computed, sa.Computed))
else:
is_(c.computed, None)
if config.requirements.computed_reflects_as_server_default.enabled:
is_true(isinstance(c.server_default, sa.DefaultClause))
eq_(str(c.server_default.arg.text), "5")
elif config.requirements.computed_reflects_normally.enabled:
is_true(isinstance(c.computed, sa.Computed))
else:
is_(c.computed, None)
@testing.combinations(
lambda: (None, sa.Computed("bar*5")),
(lambda: (sa.Computed("bar*5"), None)),
lambda: (
sa.Computed("bar*5"),
sa.Computed("bar * 42", persisted=True),
),
lambda: (sa.Computed("bar*5"), sa.Computed("bar * 42")),
)
@config.requirements.computed_reflects_normally
def test_cant_change_computed_warning(self, test_case):
arg_before, arg_after = testing.resolve_lambda(test_case, **locals())
m1 = MetaData()
m2 = MetaData()
arg_before = [] if arg_before is None else [arg_before]
arg_after = [] if arg_after is None else [arg_after]
Table(
"user",
m1,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, *arg_before),
)
Table(
"user",
m2,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, *arg_after),
)
with mock.patch("alembic.util.warn") as mock_warn:
diffs = self._fixture(m1, m2)
eq_(
mock_warn.mock_calls,
[mock.call("Computed default on user.foo cannot be modified")],
)
eq_(list(diffs), [])
@testing.combinations(
lambda: (None, None),
lambda: (sa.Computed("5"), sa.Computed("5")),
lambda: (sa.Computed("bar*5"), sa.Computed("bar*5")),
(
lambda: (sa.Computed("bar*5"), None),
config.requirements.computed_doesnt_reflect_as_server_default,
),
)
def test_computed_unchanged(self, test_case):
arg_before, arg_after = testing.resolve_lambda(test_case, **locals())
m1 = MetaData()
m2 = MetaData()
arg_before = [] if arg_before is None else [arg_before]
arg_after = [] if arg_after is None else [arg_after]
Table(
"user",
m1,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, *arg_before),
)
Table(
"user",
m2,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, *arg_after),
)
with mock.patch("alembic.util.warn") as mock_warn:
diffs = self._fixture(m1, m2)
eq_(mock_warn.mock_calls, [])
eq_(list(diffs), [])
@config.requirements.computed_reflects_as_server_default
def test_remove_computed_default_on_computed(self):
"""Asserts the current behavior which is that on PG and Oracle,
the GENERATED ALWAYS AS is reflected as a server default which we can't
tell is actually "computed", so these come out as a modification to
the server default.
"""
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, sa.Computed("bar + 42")),
)
Table(
"user",
m2,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer),
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0][0], "modify_default")
eq_(diffs[0][0][2], "user")
eq_(diffs[0][0][3], "foo")
old = diffs[0][0][-2]
new = diffs[0][0][-1]
is_(new, None)
is_true(isinstance(old, sa.DefaultClause))
if exclusions.against(config, "postgresql"):
eq_(str(old.arg.text), "(bar + 42)")
elif exclusions.against(config, "oracle"):
eq_(str(old.arg.text), '"BAR"+42')

View File

@@ -0,0 +1,273 @@
from sqlalchemy import BigInteger
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy.testing import in_
from ._autogen_fixtures import AutogenFixtureTest
from ... import testing
from ...testing import config
from ...testing import eq_
from ...testing import is_
from ...testing import TestBase
class AlterColumnTest(AutogenFixtureTest, TestBase):
__backend__ = True
@testing.combinations((True,), (False,))
@config.requirements.comments
def test_all_existings_filled(self, pk):
m1 = MetaData()
m2 = MetaData()
Table("a", m1, Column("x", Integer, primary_key=pk))
Table("a", m2, Column("x", Integer, comment="x", primary_key=pk))
alter_col = self._assert_alter_col(m1, m2, pk)
eq_(alter_col.modify_comment, "x")
@testing.combinations((True,), (False,))
@config.requirements.comments
def test_all_existings_filled_in_notnull(self, pk):
m1 = MetaData()
m2 = MetaData()
Table("a", m1, Column("x", Integer, nullable=False, primary_key=pk))
Table(
"a",
m2,
Column("x", Integer, nullable=False, comment="x", primary_key=pk),
)
self._assert_alter_col(m1, m2, pk, nullable=False)
@testing.combinations((True,), (False,))
@config.requirements.comments
def test_all_existings_filled_in_comment(self, pk):
m1 = MetaData()
m2 = MetaData()
Table("a", m1, Column("x", Integer, comment="old", primary_key=pk))
Table("a", m2, Column("x", Integer, comment="new", primary_key=pk))
alter_col = self._assert_alter_col(m1, m2, pk)
eq_(alter_col.existing_comment, "old")
@testing.combinations((True,), (False,))
@config.requirements.comments
def test_all_existings_filled_in_server_default(self, pk):
m1 = MetaData()
m2 = MetaData()
Table(
"a", m1, Column("x", Integer, server_default="5", primary_key=pk)
)
Table(
"a",
m2,
Column(
"x", Integer, server_default="5", comment="new", primary_key=pk
),
)
alter_col = self._assert_alter_col(m1, m2, pk)
in_("5", alter_col.existing_server_default.arg.text)
def _assert_alter_col(self, m1, m2, pk, nullable=None):
ops = self._fixture(m1, m2, return_ops=True)
modify_table = ops.ops[-1]
alter_col = modify_table.ops[0]
if nullable is None:
eq_(alter_col.existing_nullable, not pk)
else:
eq_(alter_col.existing_nullable, nullable)
assert alter_col.existing_type._compare_type_affinity(Integer())
return alter_col
class AutoincrementTest(AutogenFixtureTest, TestBase):
__backend__ = True
__requires__ = ("integer_subtype_comparisons",)
def test_alter_column_autoincrement_none(self):
m1 = MetaData()
m2 = MetaData()
Table("a", m1, Column("x", Integer, nullable=False))
Table("a", m2, Column("x", Integer, nullable=True))
ops = self._fixture(m1, m2, return_ops=True)
assert "autoincrement" not in ops.ops[0].ops[0].kw
def test_alter_column_autoincrement_pk_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("x", Integer, primary_key=True, autoincrement=False),
)
Table(
"a",
m2,
Column("x", BigInteger, primary_key=True, autoincrement=False),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], False)
def test_alter_column_autoincrement_pk_implicit_true(self):
m1 = MetaData()
m2 = MetaData()
Table("a", m1, Column("x", Integer, primary_key=True))
Table("a", m2, Column("x", BigInteger, primary_key=True))
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], True)
def test_alter_column_autoincrement_pk_explicit_true(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a", m1, Column("x", Integer, primary_key=True, autoincrement=True)
)
Table(
"a",
m2,
Column("x", BigInteger, primary_key=True, autoincrement=True),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], True)
def test_alter_column_autoincrement_nonpk_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True),
Column("x", Integer, autoincrement=False),
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True),
Column("x", BigInteger, autoincrement=False),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], False)
def test_alter_column_autoincrement_nonpk_implicit_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True),
Column("x", BigInteger),
)
ops = self._fixture(m1, m2, return_ops=True)
assert "autoincrement" not in ops.ops[0].ops[0].kw
def test_alter_column_autoincrement_nonpk_explicit_true(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", Integer, autoincrement=True),
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", BigInteger, autoincrement=True),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], True)
def test_alter_column_autoincrement_compositepk_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True),
Column("x", Integer, primary_key=True, autoincrement=False),
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True),
Column("x", BigInteger, primary_key=True, autoincrement=False),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], False)
def test_alter_column_autoincrement_compositepk_implicit_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True),
Column("x", Integer, primary_key=True),
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True),
Column("x", BigInteger, primary_key=True),
)
ops = self._fixture(m1, m2, return_ops=True)
assert "autoincrement" not in ops.ops[0].ops[0].kw
@config.requirements.autoincrement_on_composite_pk
def test_alter_column_autoincrement_compositepk_explicit_true(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", Integer, primary_key=True, autoincrement=True),
# on SQLA 1.0 and earlier, this being present
# trips the "add KEY for the primary key" so that the
# AUTO_INCREMENT keyword is accepted by MySQL. SQLA 1.1 and
# greater the columns are just reorganized.
mysql_engine="InnoDB",
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", BigInteger, primary_key=True, autoincrement=True),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], True)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,226 @@
import sqlalchemy as sa
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from alembic.util import sqla_compat
from ._autogen_fixtures import AutogenFixtureTest
from ... import testing
from ...testing import config
from ...testing import eq_
from ...testing import is_true
from ...testing import TestBase
class AutogenerateIdentityTest(AutogenFixtureTest, TestBase):
__requires__ = ("identity_columns",)
__backend__ = True
def test_add_identity_column(self):
m1 = MetaData()
m2 = MetaData()
Table("user", m1, Column("other", sa.Text))
Table(
"user",
m2,
Column("other", sa.Text),
Column(
"id",
Integer,
sa.Identity(start=5, increment=7),
primary_key=True,
),
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_column")
eq_(diffs[0][2], "user")
eq_(diffs[0][3].name, "id")
i = diffs[0][3].identity
is_true(isinstance(i, sa.Identity))
eq_(i.start, 5)
eq_(i.increment, 7)
def test_remove_identity_column(self):
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column(
"id",
Integer,
sa.Identity(start=2, increment=3),
primary_key=True,
),
)
Table("user", m2)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "remove_column")
eq_(diffs[0][2], "user")
c = diffs[0][3]
eq_(c.name, "id")
is_true(isinstance(c.identity, sa.Identity))
eq_(c.identity.start, 2)
eq_(c.identity.increment, 3)
def test_no_change_identity_column(self):
m1 = MetaData()
m2 = MetaData()
for m in (m1, m2):
id_ = sa.Identity(start=2)
Table("user", m, Column("id", Integer, id_))
diffs = self._fixture(m1, m2)
eq_(diffs, [])
def test_dialect_kwargs_changes(self):
m1 = MetaData()
m2 = MetaData()
if sqla_compat.identity_has_dialect_kwargs:
args = {"oracle_on_null": True, "oracle_order": True}
else:
args = {"on_null": True, "order": True}
Table("user", m1, Column("id", Integer, sa.Identity(start=2)))
id_ = sa.Identity(start=2, **args)
Table("user", m2, Column("id", Integer, id_))
diffs = self._fixture(m1, m2)
if config.db.name == "oracle":
is_true(len(diffs), 1)
eq_(diffs[0][0][0], "modify_default")
else:
eq_(diffs, [])
@testing.combinations(
(None, dict(start=2)),
(dict(start=2), None),
(dict(start=2), dict(start=2, increment=7)),
(dict(always=False), dict(always=True)),
(
dict(start=1, minvalue=0, maxvalue=100, cycle=True),
dict(start=1, minvalue=0, maxvalue=100, cycle=False),
),
(
dict(start=10, increment=3, maxvalue=9999),
dict(start=10, increment=1, maxvalue=3333),
),
)
@config.requirements.identity_columns_alter
def test_change_identity(self, before, after):
arg_before = (sa.Identity(**before),) if before else ()
arg_after = (sa.Identity(**after),) if after else ()
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer, *arg_before),
Column("other", sa.Text),
)
Table(
"user",
m2,
Column("id", Integer, *arg_after),
Column("other", sa.Text),
)
diffs = self._fixture(m1, m2)
eq_(len(diffs[0]), 1)
diffs = diffs[0][0]
eq_(diffs[0], "modify_default")
eq_(diffs[2], "user")
eq_(diffs[3], "id")
old = diffs[5]
new = diffs[6]
def check(kw, idt):
if kw:
is_true(isinstance(idt, sa.Identity))
for k, v in kw.items():
eq_(getattr(idt, k), v)
else:
is_true(idt in (None, False))
check(before, old)
check(after, new)
def test_add_identity_to_column(self):
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer),
Column("other", sa.Text),
)
Table(
"user",
m2,
Column("id", Integer, sa.Identity(start=2, maxvalue=1000)),
Column("other", sa.Text),
)
diffs = self._fixture(m1, m2)
eq_(len(diffs[0]), 1)
diffs = diffs[0][0]
eq_(diffs[0], "modify_default")
eq_(diffs[2], "user")
eq_(diffs[3], "id")
eq_(diffs[5], None)
added = diffs[6]
is_true(isinstance(added, sa.Identity))
eq_(added.start, 2)
eq_(added.maxvalue, 1000)
def test_remove_identity_from_column(self):
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer, sa.Identity(start=2, maxvalue=1000)),
Column("other", sa.Text),
)
Table(
"user",
m2,
Column("id", Integer),
Column("other", sa.Text),
)
diffs = self._fixture(m1, m2)
eq_(len(diffs[0]), 1)
diffs = diffs[0][0]
eq_(diffs[0], "modify_default")
eq_(diffs[2], "user")
eq_(diffs[3], "id")
eq_(diffs[6], None)
removed = diffs[5]
is_true(isinstance(removed, sa.Identity))

View File

@@ -0,0 +1,364 @@
import io
from ...migration import MigrationContext
from ...testing import assert_raises
from ...testing import config
from ...testing import eq_
from ...testing import is_
from ...testing import is_false
from ...testing import is_not_
from ...testing import is_true
from ...testing import ne_
from ...testing.fixtures import TestBase
class MigrationTransactionTest(TestBase):
__backend__ = True
conn = None
def _fixture(self, opts):
self.conn = conn = config.db.connect()
if opts.get("as_sql", False):
self.context = MigrationContext.configure(
dialect=conn.dialect, opts=opts
)
self.context.output_buffer = (
self.context.impl.output_buffer
) = io.StringIO()
else:
self.context = MigrationContext.configure(
connection=conn, opts=opts
)
return self.context
def teardown_method(self):
if self.conn:
self.conn.close()
def test_proxy_transaction_rollback(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
is_false(self.conn.in_transaction())
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
proxy.rollback()
is_false(self.conn.in_transaction())
def test_proxy_transaction_commit(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
proxy.commit()
is_false(self.conn.in_transaction())
def test_proxy_transaction_contextmanager_commit(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
with proxy:
pass
is_false(self.conn.in_transaction())
def test_proxy_transaction_contextmanager_rollback(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
def go():
with proxy:
raise Exception("hi")
assert_raises(Exception, go)
is_false(self.conn.in_transaction())
def test_proxy_transaction_contextmanager_explicit_rollback(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
with proxy:
is_true(self.conn.in_transaction())
proxy.rollback()
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_proxy_transaction_contextmanager_explicit_commit(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
with proxy:
is_true(self.conn.in_transaction())
proxy.commit()
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_transaction_per_migration_transactional_ddl(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
is_false(self.conn.in_transaction())
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_transaction_per_migration_non_transactional_ddl(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": False}
)
is_false(self.conn.in_transaction())
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_transaction_per_all_transactional_ddl(self):
context = self._fixture({"transactional_ddl": True})
is_false(self.conn.in_transaction())
with context.begin_transaction():
is_true(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
is_true(self.conn.in_transaction())
is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_transaction_per_all_non_transactional_ddl(self):
context = self._fixture({"transactional_ddl": False})
is_false(self.conn.in_transaction())
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_transaction_per_all_sqlmode(self):
context = self._fixture({"as_sql": True})
context.execute("step 1")
with context.begin_transaction():
context.execute("step 2")
with context.begin_transaction(_per_migration=True):
context.execute("step 3")
context.execute("step 4")
context.execute("step 5")
if context.impl.transactional_ddl:
self._assert_impl_steps(
"step 1",
"BEGIN",
"step 2",
"step 3",
"step 4",
"COMMIT",
"step 5",
)
else:
self._assert_impl_steps(
"step 1", "step 2", "step 3", "step 4", "step 5"
)
def test_transaction_per_migration_sqlmode(self):
context = self._fixture(
{"as_sql": True, "transaction_per_migration": True}
)
context.execute("step 1")
with context.begin_transaction():
context.execute("step 2")
with context.begin_transaction(_per_migration=True):
context.execute("step 3")
context.execute("step 4")
context.execute("step 5")
if context.impl.transactional_ddl:
self._assert_impl_steps(
"step 1",
"step 2",
"BEGIN",
"step 3",
"COMMIT",
"step 4",
"step 5",
)
else:
self._assert_impl_steps(
"step 1", "step 2", "step 3", "step 4", "step 5"
)
@config.requirements.autocommit_isolation
def test_autocommit_block(self):
context = self._fixture({"transaction_per_migration": True})
is_false(self.conn.in_transaction())
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
is_true(self.conn.in_transaction())
with context.autocommit_block():
# in 1.x, self.conn is separate due to the
# execution_options call. however for future they are the
# same connection and there is a "transaction" block
# despite autocommit
if self.is_sqlalchemy_future:
is_(context.connection, self.conn)
else:
is_not_(context.connection, self.conn)
is_false(self.conn.in_transaction())
eq_(
context.connection._execution_options[
"isolation_level"
],
"AUTOCOMMIT",
)
ne_(
context.connection._execution_options.get(
"isolation_level", None
),
"AUTOCOMMIT",
)
is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
@config.requirements.autocommit_isolation
def test_autocommit_block_no_transaction(self):
context = self._fixture({"transaction_per_migration": True})
is_false(self.conn.in_transaction())
with context.autocommit_block():
is_true(context.connection.in_transaction())
# in 1.x, self.conn is separate due to the execution_options
# call. however for future they are the same connection and there
# is a "transaction" block despite autocommit
if self.is_sqlalchemy_future:
is_(context.connection, self.conn)
else:
is_not_(context.connection, self.conn)
is_false(self.conn.in_transaction())
eq_(
context.connection._execution_options["isolation_level"],
"AUTOCOMMIT",
)
ne_(
context.connection._execution_options.get("isolation_level", None),
"AUTOCOMMIT",
)
is_false(self.conn.in_transaction())
def test_autocommit_block_transactional_ddl_sqlmode(self):
context = self._fixture(
{
"transaction_per_migration": True,
"transactional_ddl": True,
"as_sql": True,
}
)
with context.begin_transaction():
context.execute("step 1")
with context.begin_transaction(_per_migration=True):
context.execute("step 2")
with context.autocommit_block():
context.execute("step 3")
context.execute("step 4")
context.execute("step 5")
self._assert_impl_steps(
"step 1",
"BEGIN",
"step 2",
"COMMIT",
"step 3",
"BEGIN",
"step 4",
"COMMIT",
"step 5",
)
def test_autocommit_block_nontransactional_ddl_sqlmode(self):
context = self._fixture(
{
"transaction_per_migration": True,
"transactional_ddl": False,
"as_sql": True,
}
)
with context.begin_transaction():
context.execute("step 1")
with context.begin_transaction(_per_migration=True):
context.execute("step 2")
with context.autocommit_block():
context.execute("step 3")
context.execute("step 4")
context.execute("step 5")
self._assert_impl_steps(
"step 1", "step 2", "step 3", "step 4", "step 5"
)
def _assert_impl_steps(self, *steps):
to_check = self.context.output_buffer.getvalue()
self.context.impl.output_buffer = buf = io.StringIO()
for step in steps:
if step == "BEGIN":
self.context.impl.emit_begin()
elif step == "COMMIT":
self.context.impl.emit_commit()
else:
self.context.impl._exec(step)
eq_(to_check, buf.getvalue())

View File

@@ -0,0 +1,42 @@
"""Test against the builders in the op.* module."""
from sqlalchemy import Column
from sqlalchemy import event
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.sql import text
from ...testing.fixtures import AlterColRoundTripFixture
from ...testing.fixtures import TestBase
@event.listens_for(Table, "after_parent_attach")
def _add_cols(table, metadata):
if table.name == "tbl_with_auto_appended_column":
table.append_column(Column("bat", Integer))
class BackendAlterColumnTest(AlterColRoundTripFixture, TestBase):
__backend__ = True
def test_rename_column(self):
self._run_alter_col({}, {"name": "newname"})
def test_modify_type_int_str(self):
self._run_alter_col({"type": Integer()}, {"type": String(50)})
def test_add_server_default_int(self):
self._run_alter_col({"type": Integer}, {"server_default": text("5")})
def test_modify_server_default_int(self):
self._run_alter_col(
{"type": Integer, "server_default": text("2")},
{"server_default": text("5")},
)
def test_modify_nullable_to_non(self):
self._run_alter_col({}, {"nullable": False})
def test_modify_non_nullable_to_nullable(self):
self._run_alter_col({"nullable": False}, {"nullable": True})

View File

@@ -0,0 +1,126 @@
# testing/util.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import types
from typing import Union
from sqlalchemy.util import inspect_getfullargspec
from ..util import sqla_2
def flag_combinations(*combinations):
"""A facade around @testing.combinations() oriented towards boolean
keyword-based arguments.
Basically generates a nice looking identifier based on the keywords
and also sets up the argument names.
E.g.::
@testing.flag_combinations(
dict(lazy=False, passive=False),
dict(lazy=True, passive=False),
dict(lazy=False, passive=True),
dict(lazy=False, passive=True, raiseload=True),
)
would result in::
@testing.combinations(
('', False, False, False),
('lazy', True, False, False),
('lazy_passive', True, True, False),
('lazy_passive', True, True, True),
id_='iaaa',
argnames='lazy,passive,raiseload'
)
"""
from sqlalchemy.testing import config
keys = set()
for d in combinations:
keys.update(d)
keys = sorted(keys)
return config.combinations(
*[
("_".join(k for k in keys if d.get(k, False)),)
+ tuple(d.get(k, False) for k in keys)
for d in combinations
],
id_="i" + ("a" * len(keys)),
argnames=",".join(keys),
)
def resolve_lambda(__fn, **kw):
"""Given a no-arg lambda and a namespace, return a new lambda that
has all the values filled in.
This is used so that we can have module-level fixtures that
refer to instance-level variables using lambdas.
"""
pos_args = inspect_getfullargspec(__fn)[0]
pass_pos_args = {arg: kw.pop(arg) for arg in pos_args}
glb = dict(__fn.__globals__)
glb.update(kw)
new_fn = types.FunctionType(__fn.__code__, glb)
return new_fn(**pass_pos_args)
def metadata_fixture(ddl="function"):
"""Provide MetaData for a pytest fixture."""
from sqlalchemy.testing import config
from . import fixture_functions
def decorate(fn):
def run_ddl(self):
from sqlalchemy import schema
metadata = self.metadata = schema.MetaData()
try:
result = fn(self, metadata)
metadata.create_all(config.db)
# TODO:
# somehow get a per-function dml erase fixture here
yield result
finally:
metadata.drop_all(config.db)
return fixture_functions.fixture(scope=ddl)(run_ddl)
return decorate
def _safe_int(value: str) -> Union[int, str]:
try:
return int(value)
except:
return value
def testing_engine(url=None, options=None, future=False):
from sqlalchemy.testing import config
from sqlalchemy.testing.engines import testing_engine
if not future:
future = getattr(config._current.options, "future_engine", False)
if not sqla_2:
kw = {"future": future} if future else {}
else:
kw = {}
return testing_engine(url, options, **kw)

View File

@@ -0,0 +1,40 @@
# testing/warnings.py
# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import warnings
from sqlalchemy import exc as sa_exc
from ..util import sqla_14
def setup_filters():
"""Set global warning behavior for the test suite."""
warnings.resetwarnings()
warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
warnings.filterwarnings("error", category=sa_exc.SAWarning)
# some selected deprecations...
warnings.filterwarnings("error", category=DeprecationWarning)
if not sqla_14:
# 1.3 uses pkg_resources in PluginLoader
warnings.filterwarnings(
"ignore",
"pkg_resources is deprecated as an API",
DeprecationWarning,
)
try:
import pytest
except ImportError:
pass
else:
warnings.filterwarnings(
"once", category=pytest.PytestDeprecationWarning
)