1
0
mirror of https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git synced 2025-08-14 00:25:46 +02:00

Все подряд

This commit is contained in:
MoonTestUse1
2024-12-31 02:37:57 +06:00
parent 8e53bb6cb2
commit d5780b2eab
3258 changed files with 1087440 additions and 268 deletions

View File

@@ -0,0 +1,4 @@
from . import context
from . import op
__version__ = "1.13.1"

View File

@@ -0,0 +1,4 @@
from .config import main
if __name__ == "__main__":
main(prog="alembic")

View File

@@ -0,0 +1,10 @@
from .api import _render_migration_diffs as _render_migration_diffs
from .api import compare_metadata as compare_metadata
from .api import produce_migrations as produce_migrations
from .api import render_python_code as render_python_code
from .api import RevisionContext as RevisionContext
from .compare import _produce_net_changes as _produce_net_changes
from .compare import comparators as comparators
from .render import render_op_text as render_op_text
from .render import renderers as renderers
from .rewriter import Rewriter as Rewriter

View File

@@ -0,0 +1,650 @@
from __future__ import annotations
import contextlib
from typing import Any
from typing import Dict
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import inspect
from . import compare
from . import render
from .. import util
from ..operations import ops
from ..util import sqla_compat
"""Provide the 'autogenerate' feature which can produce migration operations
automatically."""
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine import Inspector
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import SchemaItem
from sqlalchemy.sql.schema import Table
from ..config import Config
from ..operations.ops import DowngradeOps
from ..operations.ops import MigrationScript
from ..operations.ops import UpgradeOps
from ..runtime.environment import NameFilterParentNames
from ..runtime.environment import NameFilterType
from ..runtime.environment import ProcessRevisionDirectiveFn
from ..runtime.environment import RenderItemFn
from ..runtime.migration import MigrationContext
from ..script.base import Script
from ..script.base import ScriptDirectory
from ..script.revision import _GetRevArg
def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
"""Compare a database schema to that given in a
:class:`~sqlalchemy.schema.MetaData` instance.
The database connection is presented in the context
of a :class:`.MigrationContext` object, which
provides database connectivity as well as optional
comparison functions to use for datatypes and
server defaults - see the "autogenerate" arguments
at :meth:`.EnvironmentContext.configure`
for details on these.
The return format is a list of "diff" directives,
each representing individual differences::
from alembic.migration import MigrationContext
from alembic.autogenerate import compare_metadata
from sqlalchemy import (
create_engine,
MetaData,
Column,
Integer,
String,
Table,
text,
)
import pprint
engine = create_engine("sqlite://")
with engine.begin() as conn:
conn.execute(
text(
'''
create table foo (
id integer not null primary key,
old_data varchar,
x integer
)
'''
)
)
conn.execute(text("create table bar (data varchar)"))
metadata = MetaData()
Table(
"foo",
metadata,
Column("id", Integer, primary_key=True),
Column("data", Integer),
Column("x", Integer, nullable=False),
)
Table("bat", metadata, Column("info", String))
mc = MigrationContext.configure(engine.connect())
diff = compare_metadata(mc, metadata)
pprint.pprint(diff, indent=2, width=20)
Output::
[
(
"add_table",
Table(
"bat",
MetaData(),
Column("info", String(), table=<bat>),
schema=None,
),
),
(
"remove_table",
Table(
"bar",
MetaData(),
Column("data", VARCHAR(), table=<bar>),
schema=None,
),
),
(
"add_column",
None,
"foo",
Column("data", Integer(), table=<foo>),
),
[
(
"modify_nullable",
None,
"foo",
"x",
{
"existing_comment": None,
"existing_server_default": False,
"existing_type": INTEGER(),
},
True,
False,
)
],
(
"remove_column",
None,
"foo",
Column("old_data", VARCHAR(), table=<foo>),
),
]
:param context: a :class:`.MigrationContext`
instance.
:param metadata: a :class:`~sqlalchemy.schema.MetaData`
instance.
.. seealso::
:func:`.produce_migrations` - produces a :class:`.MigrationScript`
structure based on metadata comparison.
"""
migration_script = produce_migrations(context, metadata)
assert migration_script.upgrade_ops is not None
return migration_script.upgrade_ops.as_diffs()
def produce_migrations(
context: MigrationContext, metadata: MetaData
) -> MigrationScript:
"""Produce a :class:`.MigrationScript` structure based on schema
comparison.
This function does essentially what :func:`.compare_metadata` does,
but then runs the resulting list of diffs to produce the full
:class:`.MigrationScript` object. For an example of what this looks like,
see the example in :ref:`customizing_revision`.
.. seealso::
:func:`.compare_metadata` - returns more fundamental "diff"
data from comparing a schema.
"""
autogen_context = AutogenContext(context, metadata=metadata)
migration_script = ops.MigrationScript(
rev_id=None,
upgrade_ops=ops.UpgradeOps([]),
downgrade_ops=ops.DowngradeOps([]),
)
compare._populate_migration_script(autogen_context, migration_script)
return migration_script
def render_python_code(
up_or_down_op: Union[UpgradeOps, DowngradeOps],
sqlalchemy_module_prefix: str = "sa.",
alembic_module_prefix: str = "op.",
render_as_batch: bool = False,
imports: Sequence[str] = (),
render_item: Optional[RenderItemFn] = None,
migration_context: Optional[MigrationContext] = None,
user_module_prefix: Optional[str] = None,
) -> str:
"""Render Python code given an :class:`.UpgradeOps` or
:class:`.DowngradeOps` object.
This is a convenience function that can be used to test the
autogenerate output of a user-defined :class:`.MigrationScript` structure.
:param up_or_down_op: :class:`.UpgradeOps` or :class:`.DowngradeOps` object
:param sqlalchemy_module_prefix: module prefix for SQLAlchemy objects
:param alembic_module_prefix: module prefix for Alembic constructs
:param render_as_batch: use "batch operations" style for rendering
:param imports: sequence of import symbols to add
:param render_item: callable to render items
:param migration_context: optional :class:`.MigrationContext`
:param user_module_prefix: optional string prefix for user-defined types
.. versionadded:: 1.11.0
"""
opts = {
"sqlalchemy_module_prefix": sqlalchemy_module_prefix,
"alembic_module_prefix": alembic_module_prefix,
"render_item": render_item,
"render_as_batch": render_as_batch,
"user_module_prefix": user_module_prefix,
}
if migration_context is None:
from ..runtime.migration import MigrationContext
from sqlalchemy.engine.default import DefaultDialect
migration_context = MigrationContext.configure(
dialect=DefaultDialect()
)
autogen_context = AutogenContext(migration_context, opts=opts)
autogen_context.imports = set(imports)
return render._indent(
render._render_cmd_body(up_or_down_op, autogen_context)
)
def _render_migration_diffs(
context: MigrationContext, template_args: Dict[Any, Any]
) -> None:
"""legacy, used by test_autogen_composition at the moment"""
autogen_context = AutogenContext(context)
upgrade_ops = ops.UpgradeOps([])
compare._produce_net_changes(autogen_context, upgrade_ops)
migration_script = ops.MigrationScript(
rev_id=None,
upgrade_ops=upgrade_ops,
downgrade_ops=upgrade_ops.reverse(),
)
render._render_python_into_templatevars(
autogen_context, migration_script, template_args
)
class AutogenContext:
"""Maintains configuration and state that's specific to an
autogenerate operation."""
metadata: Optional[MetaData] = None
"""The :class:`~sqlalchemy.schema.MetaData` object
representing the destination.
This object is the one that is passed within ``env.py``
to the :paramref:`.EnvironmentContext.configure.target_metadata`
parameter. It represents the structure of :class:`.Table` and other
objects as stated in the current database model, and represents the
destination structure for the database being examined.
While the :class:`~sqlalchemy.schema.MetaData` object is primarily
known as a collection of :class:`~sqlalchemy.schema.Table` objects,
it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary
that may be used by end-user schemes to store additional schema-level
objects that are to be compared in custom autogeneration schemes.
"""
connection: Optional[Connection] = None
"""The :class:`~sqlalchemy.engine.base.Connection` object currently
connected to the database backend being compared.
This is obtained from the :attr:`.MigrationContext.bind` and is
ultimately set up in the ``env.py`` script.
"""
dialect: Optional[Dialect] = None
"""The :class:`~sqlalchemy.engine.Dialect` object currently in use.
This is normally obtained from the
:attr:`~sqlalchemy.engine.base.Connection.dialect` attribute.
"""
imports: Set[str] = None # type: ignore[assignment]
"""A ``set()`` which contains string Python import directives.
The directives are to be rendered into the ``${imports}`` section
of a script template. The set is normally empty and can be modified
within hooks such as the
:paramref:`.EnvironmentContext.configure.render_item` hook.
.. seealso::
:ref:`autogen_render_types`
"""
migration_context: MigrationContext = None # type: ignore[assignment]
"""The :class:`.MigrationContext` established by the ``env.py`` script."""
def __init__(
self,
migration_context: MigrationContext,
metadata: Optional[MetaData] = None,
opts: Optional[Dict[str, Any]] = None,
autogenerate: bool = True,
) -> None:
if (
autogenerate
and migration_context is not None
and migration_context.as_sql
):
raise util.CommandError(
"autogenerate can't use as_sql=True as it prevents querying "
"the database for schema information"
)
if opts is None:
opts = migration_context.opts
self.metadata = metadata = (
opts.get("target_metadata", None) if metadata is None else metadata
)
if (
autogenerate
and metadata is None
and migration_context is not None
and migration_context.script is not None
):
raise util.CommandError(
"Can't proceed with --autogenerate option; environment "
"script %s does not provide "
"a MetaData object or sequence of objects to the context."
% (migration_context.script.env_py_location)
)
include_object = opts.get("include_object", None)
include_name = opts.get("include_name", None)
object_filters = []
name_filters = []
if include_object:
object_filters.append(include_object)
if include_name:
name_filters.append(include_name)
self._object_filters = object_filters
self._name_filters = name_filters
self.migration_context = migration_context
if self.migration_context is not None:
self.connection = self.migration_context.bind
self.dialect = self.migration_context.dialect
self.imports = set()
self.opts: Dict[str, Any] = opts
self._has_batch: bool = False
@util.memoized_property
def inspector(self) -> Inspector:
if self.connection is None:
raise TypeError(
"can't return inspector as this "
"AutogenContext has no database connection"
)
return inspect(self.connection)
@contextlib.contextmanager
def _within_batch(self) -> Iterator[None]:
self._has_batch = True
yield
self._has_batch = False
def run_name_filters(
self,
name: Optional[str],
type_: NameFilterType,
parent_names: NameFilterParentNames,
) -> bool:
"""Run the context's name filters and return True if the targets
should be part of the autogenerate operation.
This method should be run for every kind of name encountered within the
reflection side of an autogenerate operation, giving the environment
the chance to filter what names should be reflected as database
objects. The filters here are produced directly via the
:paramref:`.EnvironmentContext.configure.include_name` parameter.
"""
if "schema_name" in parent_names:
if type_ == "table":
table_name = name
else:
table_name = parent_names.get("table_name", None)
if table_name:
schema_name = parent_names["schema_name"]
if schema_name:
parent_names["schema_qualified_table_name"] = "%s.%s" % (
schema_name,
table_name,
)
else:
parent_names["schema_qualified_table_name"] = table_name
for fn in self._name_filters:
if not fn(name, type_, parent_names):
return False
else:
return True
def run_object_filters(
self,
object_: SchemaItem,
name: sqla_compat._ConstraintName,
type_: NameFilterType,
reflected: bool,
compare_to: Optional[SchemaItem],
) -> bool:
"""Run the context's object filters and return True if the targets
should be part of the autogenerate operation.
This method should be run for every kind of object encountered within
an autogenerate operation, giving the environment the chance
to filter what objects should be included in the comparison.
The filters here are produced directly via the
:paramref:`.EnvironmentContext.configure.include_object` parameter.
"""
for fn in self._object_filters:
if not fn(object_, name, type_, reflected, compare_to):
return False
else:
return True
run_filters = run_object_filters
@util.memoized_property
def sorted_tables(self) -> List[Table]:
"""Return an aggregate of the :attr:`.MetaData.sorted_tables`
collection(s).
For a sequence of :class:`.MetaData` objects, this
concatenates the :attr:`.MetaData.sorted_tables` collection
for each individual :class:`.MetaData` in the order of the
sequence. It does **not** collate the sorted tables collections.
"""
result = []
for m in util.to_list(self.metadata):
result.extend(m.sorted_tables)
return result
@util.memoized_property
def table_key_to_table(self) -> Dict[str, Table]:
"""Return an aggregate of the :attr:`.MetaData.tables` dictionaries.
The :attr:`.MetaData.tables` collection is a dictionary of table key
to :class:`.Table`; this method aggregates the dictionary across
multiple :class:`.MetaData` objects into one dictionary.
Duplicate table keys are **not** supported; if two :class:`.MetaData`
objects contain the same table key, an exception is raised.
"""
result: Dict[str, Table] = {}
for m in util.to_list(self.metadata):
intersect = set(result).intersection(set(m.tables))
if intersect:
raise ValueError(
"Duplicate table keys across multiple "
"MetaData objects: %s"
% (", ".join('"%s"' % key for key in sorted(intersect)))
)
result.update(m.tables)
return result
class RevisionContext:
"""Maintains configuration and state that's specific to a revision
file generation operation."""
generated_revisions: List[MigrationScript]
process_revision_directives: Optional[ProcessRevisionDirectiveFn]
def __init__(
self,
config: Config,
script_directory: ScriptDirectory,
command_args: Dict[str, Any],
process_revision_directives: Optional[
ProcessRevisionDirectiveFn
] = None,
) -> None:
self.config = config
self.script_directory = script_directory
self.command_args = command_args
self.process_revision_directives = process_revision_directives
self.template_args = {
"config": config # Let templates use config for
# e.g. multiple databases
}
self.generated_revisions = [self._default_revision()]
def _to_script(
self, migration_script: MigrationScript
) -> Optional[Script]:
template_args: Dict[str, Any] = self.template_args.copy()
if getattr(migration_script, "_needs_render", False):
autogen_context = self._last_autogen_context
# clear out existing imports if we are doing multiple
# renders
autogen_context.imports = set()
if migration_script.imports:
autogen_context.imports.update(migration_script.imports)
render._render_python_into_templatevars(
autogen_context, migration_script, template_args
)
assert migration_script.rev_id is not None
return self.script_directory.generate_revision(
migration_script.rev_id,
migration_script.message,
refresh=True,
head=migration_script.head,
splice=migration_script.splice,
branch_labels=migration_script.branch_label,
version_path=migration_script.version_path,
depends_on=migration_script.depends_on,
**template_args,
)
def run_autogenerate(
self, rev: _GetRevArg, migration_context: MigrationContext
) -> None:
self._run_environment(rev, migration_context, True)
def run_no_autogenerate(
self, rev: _GetRevArg, migration_context: MigrationContext
) -> None:
self._run_environment(rev, migration_context, False)
def _run_environment(
self,
rev: _GetRevArg,
migration_context: MigrationContext,
autogenerate: bool,
) -> None:
if autogenerate:
if self.command_args["sql"]:
raise util.CommandError(
"Using --sql with --autogenerate does not make any sense"
)
if set(self.script_directory.get_revisions(rev)) != set(
self.script_directory.get_revisions("heads")
):
raise util.CommandError("Target database is not up to date.")
upgrade_token = migration_context.opts["upgrade_token"]
downgrade_token = migration_context.opts["downgrade_token"]
migration_script = self.generated_revisions[-1]
if not getattr(migration_script, "_needs_render", False):
migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token
migration_script.downgrade_ops_list[
-1
].downgrade_token = downgrade_token
migration_script._needs_render = True
else:
migration_script._upgrade_ops.append(
ops.UpgradeOps([], upgrade_token=upgrade_token)
)
migration_script._downgrade_ops.append(
ops.DowngradeOps([], downgrade_token=downgrade_token)
)
autogen_context = AutogenContext(
migration_context, autogenerate=autogenerate
)
self._last_autogen_context: AutogenContext = autogen_context
if autogenerate:
compare._populate_migration_script(
autogen_context, migration_script
)
if self.process_revision_directives:
self.process_revision_directives(
migration_context, rev, self.generated_revisions
)
hook = migration_context.opts["process_revision_directives"]
if hook:
hook(migration_context, rev, self.generated_revisions)
for migration_script in self.generated_revisions:
migration_script._needs_render = True
def _default_revision(self) -> MigrationScript:
command_args: Dict[str, Any] = self.command_args
op = ops.MigrationScript(
rev_id=command_args["rev_id"] or util.rev_id(),
message=command_args["message"],
upgrade_ops=ops.UpgradeOps([]),
downgrade_ops=ops.DowngradeOps([]),
head=command_args["head"],
splice=command_args["splice"],
branch_label=command_args["branch_label"],
version_path=command_args["version_path"],
depends_on=command_args["depends_on"],
)
return op
def generate_scripts(self) -> Iterator[Optional[Script]]:
for generated_revision in self.generated_revisions:
yield self._to_script(generated_revision)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,240 @@
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Iterator
from typing import List
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from .. import util
from ..operations import ops
if TYPE_CHECKING:
from ..operations.ops import AddColumnOp
from ..operations.ops import AlterColumnOp
from ..operations.ops import CreateTableOp
from ..operations.ops import DowngradeOps
from ..operations.ops import MigrateOperation
from ..operations.ops import MigrationScript
from ..operations.ops import ModifyTableOps
from ..operations.ops import OpContainer
from ..operations.ops import UpgradeOps
from ..runtime.migration import MigrationContext
from ..script.revision import _GetRevArg
ProcessRevisionDirectiveFn = Callable[
["MigrationContext", "_GetRevArg", List["MigrationScript"]], None
]
class Rewriter:
"""A helper object that allows easy 'rewriting' of ops streams.
The :class:`.Rewriter` object is intended to be passed along
to the
:paramref:`.EnvironmentContext.configure.process_revision_directives`
parameter in an ``env.py`` script. Once constructed, any number
of "rewrites" functions can be associated with it, which will be given
the opportunity to modify the structure without having to have explicit
knowledge of the overall structure.
The function is passed the :class:`.MigrationContext` object and
``revision`` tuple that are passed to the :paramref:`.Environment
Context.configure.process_revision_directives` function normally,
and the third argument is an individual directive of the type
noted in the decorator. The function has the choice of returning
a single op directive, which normally can be the directive that
was actually passed, or a new directive to replace it, or a list
of zero or more directives to replace it.
.. seealso::
:ref:`autogen_rewriter` - usage example
"""
_traverse = util.Dispatcher()
_chained: Tuple[Union[ProcessRevisionDirectiveFn, Rewriter], ...] = ()
def __init__(self) -> None:
self.dispatch = util.Dispatcher()
def chain(
self,
other: Union[
ProcessRevisionDirectiveFn,
Rewriter,
],
) -> Rewriter:
"""Produce a "chain" of this :class:`.Rewriter` to another.
This allows two or more rewriters to operate serially on a stream,
e.g.::
writer1 = autogenerate.Rewriter()
writer2 = autogenerate.Rewriter()
@writer1.rewrites(ops.AddColumnOp)
def add_column_nullable(context, revision, op):
op.column.nullable = True
return op
@writer2.rewrites(ops.AddColumnOp)
def add_column_idx(context, revision, op):
idx_op = ops.CreateIndexOp(
"ixc", op.table_name, [op.column.name]
)
return [op, idx_op]
writer = writer1.chain(writer2)
:param other: a :class:`.Rewriter` instance
:return: a new :class:`.Rewriter` that will run the operations
of this writer, then the "other" writer, in succession.
"""
wr = self.__class__.__new__(self.__class__)
wr.__dict__.update(self.__dict__)
wr._chained += (other,)
return wr
def rewrites(
self,
operator: Union[
Type[AddColumnOp],
Type[MigrateOperation],
Type[AlterColumnOp],
Type[CreateTableOp],
Type[ModifyTableOps],
],
) -> Callable[..., Any]:
"""Register a function as rewriter for a given type.
The function should receive three arguments, which are
the :class:`.MigrationContext`, a ``revision`` tuple, and
an op directive of the type indicated. E.g.::
@writer1.rewrites(ops.AddColumnOp)
def add_column_nullable(context, revision, op):
op.column.nullable = True
return op
"""
return self.dispatch.dispatch_for(operator)
def _rewrite(
self,
context: MigrationContext,
revision: _GetRevArg,
directive: MigrateOperation,
) -> Iterator[MigrateOperation]:
try:
_rewriter = self.dispatch.dispatch(directive)
except ValueError:
_rewriter = None
yield directive
else:
if self in directive._mutations:
yield directive
else:
for r_directive in util.to_list(
_rewriter(context, revision, directive), []
):
r_directive._mutations = r_directive._mutations.union(
[self]
)
yield r_directive
def __call__(
self,
context: MigrationContext,
revision: _GetRevArg,
directives: List[MigrationScript],
) -> None:
self.process_revision_directives(context, revision, directives)
for process_revision_directives in self._chained:
process_revision_directives(context, revision, directives)
@_traverse.dispatch_for(ops.MigrationScript)
def _traverse_script(
self,
context: MigrationContext,
revision: _GetRevArg,
directive: MigrationScript,
) -> None:
upgrade_ops_list: List[UpgradeOps] = []
for upgrade_ops in directive.upgrade_ops_list:
ret = self._traverse_for(context, revision, upgrade_ops)
if len(ret) != 1:
raise ValueError(
"Can only return single object for UpgradeOps traverse"
)
upgrade_ops_list.append(ret[0])
directive.upgrade_ops = upgrade_ops_list # type: ignore
downgrade_ops_list: List[DowngradeOps] = []
for downgrade_ops in directive.downgrade_ops_list:
ret = self._traverse_for(context, revision, downgrade_ops)
if len(ret) != 1:
raise ValueError(
"Can only return single object for DowngradeOps traverse"
)
downgrade_ops_list.append(ret[0])
directive.downgrade_ops = downgrade_ops_list # type: ignore
@_traverse.dispatch_for(ops.OpContainer)
def _traverse_op_container(
self,
context: MigrationContext,
revision: _GetRevArg,
directive: OpContainer,
) -> None:
self._traverse_list(context, revision, directive.ops)
@_traverse.dispatch_for(ops.MigrateOperation)
def _traverse_any_directive(
self,
context: MigrationContext,
revision: _GetRevArg,
directive: MigrateOperation,
) -> None:
pass
def _traverse_for(
self,
context: MigrationContext,
revision: _GetRevArg,
directive: MigrateOperation,
) -> Any:
directives = list(self._rewrite(context, revision, directive))
for directive in directives:
traverser = self._traverse.dispatch(directive)
traverser(self, context, revision, directive)
return directives
def _traverse_list(
self,
context: MigrationContext,
revision: _GetRevArg,
directives: Any,
) -> None:
dest = []
for directive in directives:
dest.extend(self._traverse_for(context, revision, directive))
directives[:] = dest
def process_revision_directives(
self,
context: MigrationContext,
revision: _GetRevArg,
directives: List[MigrationScript],
) -> None:
self._traverse_list(context, revision, directives)

View File

@@ -0,0 +1,749 @@
# mypy: allow-untyped-defs, allow-untyped-calls
from __future__ import annotations
import os
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from . import autogenerate as autogen
from . import util
from .runtime.environment import EnvironmentContext
from .script import ScriptDirectory
if TYPE_CHECKING:
from alembic.config import Config
from alembic.script.base import Script
from alembic.script.revision import _RevIdType
from .runtime.environment import ProcessRevisionDirectiveFn
def list_templates(config: Config) -> None:
"""List available templates.
:param config: a :class:`.Config` object.
"""
config.print_stdout("Available templates:\n")
for tempname in os.listdir(config.get_template_directory()):
with open(
os.path.join(config.get_template_directory(), tempname, "README")
) as readme:
synopsis = next(readme).rstrip()
config.print_stdout("%s - %s", tempname, synopsis)
config.print_stdout("\nTemplates are used via the 'init' command, e.g.:")
config.print_stdout("\n alembic init --template generic ./scripts")
def init(
config: Config,
directory: str,
template: str = "generic",
package: bool = False,
) -> None:
"""Initialize a new scripts directory.
:param config: a :class:`.Config` object.
:param directory: string path of the target directory
:param template: string name of the migration environment template to
use.
:param package: when True, write ``__init__.py`` files into the
environment location as well as the versions/ location.
"""
if os.access(directory, os.F_OK) and os.listdir(directory):
raise util.CommandError(
"Directory %s already exists and is not empty" % directory
)
template_dir = os.path.join(config.get_template_directory(), template)
if not os.access(template_dir, os.F_OK):
raise util.CommandError("No such template %r" % template)
if not os.access(directory, os.F_OK):
with util.status(
f"Creating directory {os.path.abspath(directory)!r}",
**config.messaging_opts,
):
os.makedirs(directory)
versions = os.path.join(directory, "versions")
with util.status(
f"Creating directory {os.path.abspath(versions)!r}",
**config.messaging_opts,
):
os.makedirs(versions)
script = ScriptDirectory(directory)
config_file: str | None = None
for file_ in os.listdir(template_dir):
file_path = os.path.join(template_dir, file_)
if file_ == "alembic.ini.mako":
assert config.config_file_name is not None
config_file = os.path.abspath(config.config_file_name)
if os.access(config_file, os.F_OK):
util.msg(
f"File {config_file!r} already exists, skipping",
**config.messaging_opts,
)
else:
script._generate_template(
file_path, config_file, script_location=directory
)
elif os.path.isfile(file_path):
output_file = os.path.join(directory, file_)
script._copy_file(file_path, output_file)
if package:
for path in [
os.path.join(os.path.abspath(directory), "__init__.py"),
os.path.join(os.path.abspath(versions), "__init__.py"),
]:
with util.status(f"Adding {path!r}", **config.messaging_opts):
with open(path, "w"):
pass
assert config_file is not None
util.msg(
"Please edit configuration/connection/logging "
f"settings in {config_file!r} before proceeding.",
**config.messaging_opts,
)
def revision(
config: Config,
message: Optional[str] = None,
autogenerate: bool = False,
sql: bool = False,
head: str = "head",
splice: bool = False,
branch_label: Optional[_RevIdType] = None,
version_path: Optional[str] = None,
rev_id: Optional[str] = None,
depends_on: Optional[str] = None,
process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None,
) -> Union[Optional[Script], List[Optional[Script]]]:
"""Create a new revision file.
:param config: a :class:`.Config` object.
:param message: string message to apply to the revision; this is the
``-m`` option to ``alembic revision``.
:param autogenerate: whether or not to autogenerate the script from
the database; this is the ``--autogenerate`` option to
``alembic revision``.
:param sql: whether to dump the script out as a SQL string; when specified,
the script is dumped to stdout. This is the ``--sql`` option to
``alembic revision``.
:param head: head revision to build the new revision upon as a parent;
this is the ``--head`` option to ``alembic revision``.
:param splice: whether or not the new revision should be made into a
new head of its own; is required when the given ``head`` is not itself
a head. This is the ``--splice`` option to ``alembic revision``.
:param branch_label: string label to apply to the branch; this is the
``--branch-label`` option to ``alembic revision``.
:param version_path: string symbol identifying a specific version path
from the configuration; this is the ``--version-path`` option to
``alembic revision``.
:param rev_id: optional revision identifier to use instead of having
one generated; this is the ``--rev-id`` option to ``alembic revision``.
:param depends_on: optional list of "depends on" identifiers; this is the
``--depends-on`` option to ``alembic revision``.
:param process_revision_directives: this is a callable that takes the
same form as the callable described at
:paramref:`.EnvironmentContext.configure.process_revision_directives`;
will be applied to the structure generated by the revision process
where it can be altered programmatically. Note that unlike all
the other parameters, this option is only available via programmatic
use of :func:`.command.revision`
"""
script_directory = ScriptDirectory.from_config(config)
command_args = dict(
message=message,
autogenerate=autogenerate,
sql=sql,
head=head,
splice=splice,
branch_label=branch_label,
version_path=version_path,
rev_id=rev_id,
depends_on=depends_on,
)
revision_context = autogen.RevisionContext(
config,
script_directory,
command_args,
process_revision_directives=process_revision_directives,
)
environment = util.asbool(config.get_main_option("revision_environment"))
if autogenerate:
environment = True
if sql:
raise util.CommandError(
"Using --sql with --autogenerate does not make any sense"
)
def retrieve_migrations(rev, context):
revision_context.run_autogenerate(rev, context)
return []
elif environment:
def retrieve_migrations(rev, context):
revision_context.run_no_autogenerate(rev, context)
return []
elif sql:
raise util.CommandError(
"Using --sql with the revision command when "
"revision_environment is not configured does not make any sense"
)
if environment:
with EnvironmentContext(
config,
script_directory,
fn=retrieve_migrations,
as_sql=sql,
template_args=revision_context.template_args,
revision_context=revision_context,
):
script_directory.run_env()
# the revision_context now has MigrationScript structure(s) present.
# these could theoretically be further processed / rewritten *here*,
# in addition to the hooks present within each run_migrations() call,
# or at the end of env.py run_migrations_online().
scripts = [script for script in revision_context.generate_scripts()]
if len(scripts) == 1:
return scripts[0]
else:
return scripts
def check(config: "Config") -> None:
"""Check if revision command with autogenerate has pending upgrade ops.
:param config: a :class:`.Config` object.
.. versionadded:: 1.9.0
"""
script_directory = ScriptDirectory.from_config(config)
command_args = dict(
message=None,
autogenerate=True,
sql=False,
head="head",
splice=False,
branch_label=None,
version_path=None,
rev_id=None,
depends_on=None,
)
revision_context = autogen.RevisionContext(
config,
script_directory,
command_args,
)
def retrieve_migrations(rev, context):
revision_context.run_autogenerate(rev, context)
return []
with EnvironmentContext(
config,
script_directory,
fn=retrieve_migrations,
as_sql=False,
template_args=revision_context.template_args,
revision_context=revision_context,
):
script_directory.run_env()
# the revision_context now has MigrationScript structure(s) present.
migration_script = revision_context.generated_revisions[-1]
diffs = []
for upgrade_ops in migration_script.upgrade_ops_list:
diffs.extend(upgrade_ops.as_diffs())
if diffs:
raise util.AutogenerateDiffsDetected(
f"New upgrade operations detected: {diffs}"
)
else:
config.print_stdout("No new upgrade operations detected.")
def merge(
config: Config,
revisions: _RevIdType,
message: Optional[str] = None,
branch_label: Optional[_RevIdType] = None,
rev_id: Optional[str] = None,
) -> Optional[Script]:
"""Merge two revisions together. Creates a new migration file.
:param config: a :class:`.Config` instance
:param message: string message to apply to the revision
:param branch_label: string label name to apply to the new revision
:param rev_id: hardcoded revision identifier instead of generating a new
one.
.. seealso::
:ref:`branches`
"""
script = ScriptDirectory.from_config(config)
template_args = {
"config": config # Let templates use config for
# e.g. multiple databases
}
environment = util.asbool(config.get_main_option("revision_environment"))
if environment:
def nothing(rev, context):
return []
with EnvironmentContext(
config,
script,
fn=nothing,
as_sql=False,
template_args=template_args,
):
script.run_env()
return script.generate_revision(
rev_id or util.rev_id(),
message,
refresh=True,
head=revisions,
branch_labels=branch_label,
**template_args, # type:ignore[arg-type]
)
def upgrade(
config: Config,
revision: str,
sql: bool = False,
tag: Optional[str] = None,
) -> None:
"""Upgrade to a later version.
:param config: a :class:`.Config` instance.
:param revision: string revision target or range for --sql mode
:param sql: if True, use ``--sql`` mode
:param tag: an arbitrary "tag" that can be intercepted by custom
``env.py`` scripts via the :meth:`.EnvironmentContext.get_tag_argument`
method.
"""
script = ScriptDirectory.from_config(config)
starting_rev = None
if ":" in revision:
if not sql:
raise util.CommandError("Range revision not allowed")
starting_rev, revision = revision.split(":", 2)
def upgrade(rev, context):
return script._upgrade_revs(revision, rev)
with EnvironmentContext(
config,
script,
fn=upgrade,
as_sql=sql,
starting_rev=starting_rev,
destination_rev=revision,
tag=tag,
):
script.run_env()
def downgrade(
config: Config,
revision: str,
sql: bool = False,
tag: Optional[str] = None,
) -> None:
"""Revert to a previous version.
:param config: a :class:`.Config` instance.
:param revision: string revision target or range for --sql mode
:param sql: if True, use ``--sql`` mode
:param tag: an arbitrary "tag" that can be intercepted by custom
``env.py`` scripts via the :meth:`.EnvironmentContext.get_tag_argument`
method.
"""
script = ScriptDirectory.from_config(config)
starting_rev = None
if ":" in revision:
if not sql:
raise util.CommandError("Range revision not allowed")
starting_rev, revision = revision.split(":", 2)
elif sql:
raise util.CommandError(
"downgrade with --sql requires <fromrev>:<torev>"
)
def downgrade(rev, context):
return script._downgrade_revs(revision, rev)
with EnvironmentContext(
config,
script,
fn=downgrade,
as_sql=sql,
starting_rev=starting_rev,
destination_rev=revision,
tag=tag,
):
script.run_env()
def show(config, rev):
"""Show the revision(s) denoted by the given symbol.
:param config: a :class:`.Config` instance.
:param revision: string revision target
"""
script = ScriptDirectory.from_config(config)
if rev == "current":
def show_current(rev, context):
for sc in script.get_revisions(rev):
config.print_stdout(sc.log_entry)
return []
with EnvironmentContext(config, script, fn=show_current):
script.run_env()
else:
for sc in script.get_revisions(rev):
config.print_stdout(sc.log_entry)
def history(
config: Config,
rev_range: Optional[str] = None,
verbose: bool = False,
indicate_current: bool = False,
) -> None:
"""List changeset scripts in chronological order.
:param config: a :class:`.Config` instance.
:param rev_range: string revision range
:param verbose: output in verbose mode.
:param indicate_current: indicate current revision.
"""
base: Optional[str]
head: Optional[str]
script = ScriptDirectory.from_config(config)
if rev_range is not None:
if ":" not in rev_range:
raise util.CommandError(
"History range requires [start]:[end], " "[start]:, or :[end]"
)
base, head = rev_range.strip().split(":")
else:
base = head = None
environment = (
util.asbool(config.get_main_option("revision_environment"))
or indicate_current
)
def _display_history(config, script, base, head, currents=()):
for sc in script.walk_revisions(
base=base or "base", head=head or "heads"
):
if indicate_current:
sc._db_current_indicator = sc.revision in currents
config.print_stdout(
sc.cmd_format(
verbose=verbose,
include_branches=True,
include_doc=True,
include_parents=True,
)
)
def _display_history_w_current(config, script, base, head):
def _display_current_history(rev, context):
if head == "current":
_display_history(config, script, base, rev, rev)
elif base == "current":
_display_history(config, script, rev, head, rev)
else:
_display_history(config, script, base, head, rev)
return []
with EnvironmentContext(config, script, fn=_display_current_history):
script.run_env()
if base == "current" or head == "current" or environment:
_display_history_w_current(config, script, base, head)
else:
_display_history(config, script, base, head)
def heads(config, verbose=False, resolve_dependencies=False):
"""Show current available heads in the script directory.
:param config: a :class:`.Config` instance.
:param verbose: output in verbose mode.
:param resolve_dependencies: treat dependency version as down revisions.
"""
script = ScriptDirectory.from_config(config)
if resolve_dependencies:
heads = script.get_revisions("heads")
else:
heads = script.get_revisions(script.get_heads())
for rev in heads:
config.print_stdout(
rev.cmd_format(
verbose, include_branches=True, tree_indicators=False
)
)
def branches(config, verbose=False):
"""Show current branch points.
:param config: a :class:`.Config` instance.
:param verbose: output in verbose mode.
"""
script = ScriptDirectory.from_config(config)
for sc in script.walk_revisions():
if sc.is_branch_point:
config.print_stdout(
"%s\n%s\n",
sc.cmd_format(verbose, include_branches=True),
"\n".join(
"%s -> %s"
% (
" " * len(str(sc.revision)),
rev_obj.cmd_format(
False, include_branches=True, include_doc=verbose
),
)
for rev_obj in (
script.get_revision(rev) for rev in sc.nextrev
)
),
)
def current(config: Config, verbose: bool = False) -> None:
"""Display the current revision for a database.
:param config: a :class:`.Config` instance.
:param verbose: output in verbose mode.
"""
script = ScriptDirectory.from_config(config)
def display_version(rev, context):
if verbose:
config.print_stdout(
"Current revision(s) for %s:",
util.obfuscate_url_pw(context.connection.engine.url),
)
for rev in script.get_all_current(rev):
config.print_stdout(rev.cmd_format(verbose))
return []
with EnvironmentContext(
config, script, fn=display_version, dont_mutate=True
):
script.run_env()
def stamp(
config: Config,
revision: _RevIdType,
sql: bool = False,
tag: Optional[str] = None,
purge: bool = False,
) -> None:
"""'stamp' the revision table with the given revision; don't
run any migrations.
:param config: a :class:`.Config` instance.
:param revision: target revision or list of revisions. May be a list
to indicate stamping of multiple branch heads.
.. note:: this parameter is called "revisions" in the command line
interface.
:param sql: use ``--sql`` mode
:param tag: an arbitrary "tag" that can be intercepted by custom
``env.py`` scripts via the :class:`.EnvironmentContext.get_tag_argument`
method.
:param purge: delete all entries in the version table before stamping.
"""
script = ScriptDirectory.from_config(config)
if sql:
destination_revs = []
starting_rev = None
for _revision in util.to_list(revision):
if ":" in _revision:
srev, _revision = _revision.split(":", 2)
if starting_rev != srev:
if starting_rev is None:
starting_rev = srev
else:
raise util.CommandError(
"Stamp operation with --sql only supports a "
"single starting revision at a time"
)
destination_revs.append(_revision)
else:
destination_revs = util.to_list(revision)
def do_stamp(rev, context):
return script._stamp_revs(util.to_tuple(destination_revs), rev)
with EnvironmentContext(
config,
script,
fn=do_stamp,
as_sql=sql,
starting_rev=starting_rev if sql else None,
destination_rev=util.to_tuple(destination_revs),
tag=tag,
purge=purge,
):
script.run_env()
def edit(config: Config, rev: str) -> None:
"""Edit revision script(s) using $EDITOR.
:param config: a :class:`.Config` instance.
:param rev: target revision.
"""
script = ScriptDirectory.from_config(config)
if rev == "current":
def edit_current(rev, context):
if not rev:
raise util.CommandError("No current revisions")
for sc in script.get_revisions(rev):
util.open_in_editor(sc.path)
return []
with EnvironmentContext(config, script, fn=edit_current):
script.run_env()
else:
revs = script.get_revisions(rev)
if not revs:
raise util.CommandError(
"No revision files indicated by symbol '%s'" % rev
)
for sc in revs:
assert sc
util.open_in_editor(sc.path)
def ensure_version(config: Config, sql: bool = False) -> None:
"""Create the alembic version table if it doesn't exist already .
:param config: a :class:`.Config` instance.
:param sql: use ``--sql`` mode
.. versionadded:: 1.7.6
"""
script = ScriptDirectory.from_config(config)
def do_ensure_version(rev, context):
context._ensure_version_table()
return []
with EnvironmentContext(
config,
script,
fn=do_ensure_version,
as_sql=sql,
):
script.run_env()

View File

@@ -0,0 +1,645 @@
from __future__ import annotations
from argparse import ArgumentParser
from argparse import Namespace
from configparser import ConfigParser
import inspect
import os
import sys
from typing import Any
from typing import cast
from typing import Dict
from typing import Mapping
from typing import Optional
from typing import overload
from typing import Sequence
from typing import TextIO
from typing import Union
from typing_extensions import TypedDict
from . import __version__
from . import command
from . import util
from .util import compat
class Config:
r"""Represent an Alembic configuration.
Within an ``env.py`` script, this is available
via the :attr:`.EnvironmentContext.config` attribute,
which in turn is available at ``alembic.context``::
from alembic import context
some_param = context.config.get_main_option("my option")
When invoking Alembic programmatically, a new
:class:`.Config` can be created by passing
the name of an .ini file to the constructor::
from alembic.config import Config
alembic_cfg = Config("/path/to/yourapp/alembic.ini")
With a :class:`.Config` object, you can then
run Alembic commands programmatically using the directives
in :mod:`alembic.command`.
The :class:`.Config` object can also be constructed without
a filename. Values can be set programmatically, and
new sections will be created as needed::
from alembic.config import Config
alembic_cfg = Config()
alembic_cfg.set_main_option("script_location", "myapp:migrations")
alembic_cfg.set_main_option("sqlalchemy.url", "postgresql://foo/bar")
alembic_cfg.set_section_option("mysection", "foo", "bar")
.. warning::
When using programmatic configuration, make sure the
``env.py`` file in use is compatible with the target configuration;
including that the call to Python ``logging.fileConfig()`` is
omitted if the programmatic configuration doesn't actually include
logging directives.
For passing non-string values to environments, such as connections and
engines, use the :attr:`.Config.attributes` dictionary::
with engine.begin() as connection:
alembic_cfg.attributes['connection'] = connection
command.upgrade(alembic_cfg, "head")
:param file\_: name of the .ini file to open.
:param ini_section: name of the main Alembic section within the
.ini file
:param output_buffer: optional file-like input buffer which
will be passed to the :class:`.MigrationContext` - used to redirect
the output of "offline generation" when using Alembic programmatically.
:param stdout: buffer where the "print" output of commands will be sent.
Defaults to ``sys.stdout``.
:param config_args: A dictionary of keys and values that will be used
for substitution in the alembic config file. The dictionary as given
is **copied** to a new one, stored locally as the attribute
``.config_args``. When the :attr:`.Config.file_config` attribute is
first invoked, the replacement variable ``here`` will be added to this
dictionary before the dictionary is passed to ``ConfigParser()``
to parse the .ini file.
:param attributes: optional dictionary of arbitrary Python keys/values,
which will be populated into the :attr:`.Config.attributes` dictionary.
.. seealso::
:ref:`connection_sharing`
"""
def __init__(
self,
file_: Union[str, os.PathLike[str], None] = None,
ini_section: str = "alembic",
output_buffer: Optional[TextIO] = None,
stdout: TextIO = sys.stdout,
cmd_opts: Optional[Namespace] = None,
config_args: Mapping[str, Any] = util.immutabledict(),
attributes: Optional[Dict[str, Any]] = None,
) -> None:
"""Construct a new :class:`.Config`"""
self.config_file_name = file_
self.config_ini_section = ini_section
self.output_buffer = output_buffer
self.stdout = stdout
self.cmd_opts = cmd_opts
self.config_args = dict(config_args)
if attributes:
self.attributes.update(attributes)
cmd_opts: Optional[Namespace] = None
"""The command-line options passed to the ``alembic`` script.
Within an ``env.py`` script this can be accessed via the
:attr:`.EnvironmentContext.config` attribute.
.. seealso::
:meth:`.EnvironmentContext.get_x_argument`
"""
config_file_name: Union[str, os.PathLike[str], None] = None
"""Filesystem path to the .ini file in use."""
config_ini_section: str = None # type:ignore[assignment]
"""Name of the config file section to read basic configuration
from. Defaults to ``alembic``, that is the ``[alembic]`` section
of the .ini file. This value is modified using the ``-n/--name``
option to the Alembic runner.
"""
@util.memoized_property
def attributes(self) -> Dict[str, Any]:
"""A Python dictionary for storage of additional state.
This is a utility dictionary which can include not just strings but
engines, connections, schema objects, or anything else.
Use this to pass objects into an env.py script, such as passing
a :class:`sqlalchemy.engine.base.Connection` when calling
commands from :mod:`alembic.command` programmatically.
.. seealso::
:ref:`connection_sharing`
:paramref:`.Config.attributes`
"""
return {}
def print_stdout(self, text: str, *arg: Any) -> None:
"""Render a message to standard out.
When :meth:`.Config.print_stdout` is called with additional args
those arguments will formatted against the provided text,
otherwise we simply output the provided text verbatim.
This is a no-op when the``quiet`` messaging option is enabled.
e.g.::
>>> config.print_stdout('Some text %s', 'arg')
Some Text arg
"""
if arg:
output = str(text) % arg
else:
output = str(text)
util.write_outstream(self.stdout, output, "\n", **self.messaging_opts)
@util.memoized_property
def file_config(self) -> ConfigParser:
"""Return the underlying ``ConfigParser`` object.
Direct access to the .ini file is available here,
though the :meth:`.Config.get_section` and
:meth:`.Config.get_main_option`
methods provide a possibly simpler interface.
"""
if self.config_file_name:
here = os.path.abspath(os.path.dirname(self.config_file_name))
else:
here = ""
self.config_args["here"] = here
file_config = ConfigParser(self.config_args)
if self.config_file_name:
compat.read_config_parser(file_config, [self.config_file_name])
else:
file_config.add_section(self.config_ini_section)
return file_config
def get_template_directory(self) -> str:
"""Return the directory where Alembic setup templates are found.
This method is used by the alembic ``init`` and ``list_templates``
commands.
"""
import alembic
package_dir = os.path.abspath(os.path.dirname(alembic.__file__))
return os.path.join(package_dir, "templates")
@overload
def get_section(
self, name: str, default: None = ...
) -> Optional[Dict[str, str]]:
...
# "default" here could also be a TypeVar
# _MT = TypeVar("_MT", bound=Mapping[str, str]),
# however mypy wasn't handling that correctly (pyright was)
@overload
def get_section(
self, name: str, default: Dict[str, str]
) -> Dict[str, str]:
...
@overload
def get_section(
self, name: str, default: Mapping[str, str]
) -> Union[Dict[str, str], Mapping[str, str]]:
...
def get_section(
self, name: str, default: Optional[Mapping[str, str]] = None
) -> Optional[Mapping[str, str]]:
"""Return all the configuration options from a given .ini file section
as a dictionary.
If the given section does not exist, the value of ``default``
is returned, which is expected to be a dictionary or other mapping.
"""
if not self.file_config.has_section(name):
return default
return dict(self.file_config.items(name))
def set_main_option(self, name: str, value: str) -> None:
"""Set an option programmatically within the 'main' section.
This overrides whatever was in the .ini file.
:param name: name of the value
:param value: the value. Note that this value is passed to
``ConfigParser.set``, which supports variable interpolation using
pyformat (e.g. ``%(some_value)s``). A raw percent sign not part of
an interpolation symbol must therefore be escaped, e.g. ``%%``.
The given value may refer to another value already in the file
using the interpolation format.
"""
self.set_section_option(self.config_ini_section, name, value)
def remove_main_option(self, name: str) -> None:
self.file_config.remove_option(self.config_ini_section, name)
def set_section_option(self, section: str, name: str, value: str) -> None:
"""Set an option programmatically within the given section.
The section is created if it doesn't exist already.
The value here will override whatever was in the .ini
file.
:param section: name of the section
:param name: name of the value
:param value: the value. Note that this value is passed to
``ConfigParser.set``, which supports variable interpolation using
pyformat (e.g. ``%(some_value)s``). A raw percent sign not part of
an interpolation symbol must therefore be escaped, e.g. ``%%``.
The given value may refer to another value already in the file
using the interpolation format.
"""
if not self.file_config.has_section(section):
self.file_config.add_section(section)
self.file_config.set(section, name, value)
def get_section_option(
self, section: str, name: str, default: Optional[str] = None
) -> Optional[str]:
"""Return an option from the given section of the .ini file."""
if not self.file_config.has_section(section):
raise util.CommandError(
"No config file %r found, or file has no "
"'[%s]' section" % (self.config_file_name, section)
)
if self.file_config.has_option(section, name):
return self.file_config.get(section, name)
else:
return default
@overload
def get_main_option(self, name: str, default: str) -> str:
...
@overload
def get_main_option(
self, name: str, default: Optional[str] = None
) -> Optional[str]:
...
def get_main_option(
self, name: str, default: Optional[str] = None
) -> Optional[str]:
"""Return an option from the 'main' section of the .ini file.
This defaults to being a key from the ``[alembic]``
section, unless the ``-n/--name`` flag were used to
indicate a different section.
"""
return self.get_section_option(self.config_ini_section, name, default)
@util.memoized_property
def messaging_opts(self) -> MessagingOptions:
"""The messaging options."""
return cast(
MessagingOptions,
util.immutabledict(
{"quiet": getattr(self.cmd_opts, "quiet", False)}
),
)
class MessagingOptions(TypedDict, total=False):
quiet: bool
class CommandLine:
def __init__(self, prog: Optional[str] = None) -> None:
self._generate_args(prog)
def _generate_args(self, prog: Optional[str]) -> None:
def add_options(
fn: Any, parser: Any, positional: Any, kwargs: Any
) -> None:
kwargs_opts = {
"template": (
"-t",
"--template",
dict(
default="generic",
type=str,
help="Setup template for use with 'init'",
),
),
"message": (
"-m",
"--message",
dict(
type=str, help="Message string to use with 'revision'"
),
),
"sql": (
"--sql",
dict(
action="store_true",
help="Don't emit SQL to database - dump to "
"standard output/file instead. See docs on "
"offline mode.",
),
),
"tag": (
"--tag",
dict(
type=str,
help="Arbitrary 'tag' name - can be used by "
"custom env.py scripts.",
),
),
"head": (
"--head",
dict(
type=str,
help="Specify head revision or <branchname>@head "
"to base new revision on.",
),
),
"splice": (
"--splice",
dict(
action="store_true",
help="Allow a non-head revision as the "
"'head' to splice onto",
),
),
"depends_on": (
"--depends-on",
dict(
action="append",
help="Specify one or more revision identifiers "
"which this revision should depend on.",
),
),
"rev_id": (
"--rev-id",
dict(
type=str,
help="Specify a hardcoded revision id instead of "
"generating one",
),
),
"version_path": (
"--version-path",
dict(
type=str,
help="Specify specific path from config for "
"version file",
),
),
"branch_label": (
"--branch-label",
dict(
type=str,
help="Specify a branch label to apply to the "
"new revision",
),
),
"verbose": (
"-v",
"--verbose",
dict(action="store_true", help="Use more verbose output"),
),
"resolve_dependencies": (
"--resolve-dependencies",
dict(
action="store_true",
help="Treat dependency versions as down revisions",
),
),
"autogenerate": (
"--autogenerate",
dict(
action="store_true",
help="Populate revision script with candidate "
"migration operations, based on comparison "
"of database to model.",
),
),
"rev_range": (
"-r",
"--rev-range",
dict(
action="store",
help="Specify a revision range; "
"format is [start]:[end]",
),
),
"indicate_current": (
"-i",
"--indicate-current",
dict(
action="store_true",
help="Indicate the current revision",
),
),
"purge": (
"--purge",
dict(
action="store_true",
help="Unconditionally erase the version table "
"before stamping",
),
),
"package": (
"--package",
dict(
action="store_true",
help="Write empty __init__.py files to the "
"environment and version locations",
),
),
}
positional_help = {
"directory": "location of scripts directory",
"revision": "revision identifier",
"revisions": "one or more revisions, or 'heads' for all heads",
}
for arg in kwargs:
if arg in kwargs_opts:
args = kwargs_opts[arg]
args, kw = args[0:-1], args[-1]
parser.add_argument(*args, **kw)
for arg in positional:
if (
arg == "revisions"
or fn in positional_translations
and positional_translations[fn][arg] == "revisions"
):
subparser.add_argument(
"revisions",
nargs="+",
help=positional_help.get("revisions"),
)
else:
subparser.add_argument(arg, help=positional_help.get(arg))
parser = ArgumentParser(prog=prog)
parser.add_argument(
"--version", action="version", version="%%(prog)s %s" % __version__
)
parser.add_argument(
"-c",
"--config",
type=str,
default=os.environ.get("ALEMBIC_CONFIG", "alembic.ini"),
help="Alternate config file; defaults to value of "
'ALEMBIC_CONFIG environment variable, or "alembic.ini"',
)
parser.add_argument(
"-n",
"--name",
type=str,
default="alembic",
help="Name of section in .ini file to " "use for Alembic config",
)
parser.add_argument(
"-x",
action="append",
help="Additional arguments consumed by "
"custom env.py scripts, e.g. -x "
"setting1=somesetting -x setting2=somesetting",
)
parser.add_argument(
"--raiseerr",
action="store_true",
help="Raise a full stack trace on error",
)
parser.add_argument(
"-q",
"--quiet",
action="store_true",
help="Do not log to std output.",
)
subparsers = parser.add_subparsers()
positional_translations: Dict[Any, Any] = {
command.stamp: {"revision": "revisions"}
}
for fn in [getattr(command, n) for n in dir(command)]:
if (
inspect.isfunction(fn)
and fn.__name__[0] != "_"
and fn.__module__ == "alembic.command"
):
spec = compat.inspect_getfullargspec(fn)
if spec[3] is not None:
positional = spec[0][1 : -len(spec[3])]
kwarg = spec[0][-len(spec[3]) :]
else:
positional = spec[0][1:]
kwarg = []
if fn in positional_translations:
positional = [
positional_translations[fn].get(name, name)
for name in positional
]
# parse first line(s) of helptext without a line break
help_ = fn.__doc__
if help_:
help_text = []
for line in help_.split("\n"):
if not line.strip():
break
else:
help_text.append(line.strip())
else:
help_text = []
subparser = subparsers.add_parser(
fn.__name__, help=" ".join(help_text)
)
add_options(fn, subparser, positional, kwarg)
subparser.set_defaults(cmd=(fn, positional, kwarg))
self.parser = parser
def run_cmd(self, config: Config, options: Namespace) -> None:
fn, positional, kwarg = options.cmd
try:
fn(
config,
*[getattr(options, k, None) for k in positional],
**{k: getattr(options, k, None) for k in kwarg},
)
except util.CommandError as e:
if options.raiseerr:
raise
else:
util.err(str(e), **config.messaging_opts)
def main(self, argv: Optional[Sequence[str]] = None) -> None:
options = self.parser.parse_args(argv)
if not hasattr(options, "cmd"):
# see http://bugs.python.org/issue9253, argparse
# behavior changed incompatibly in py3.3
self.parser.error("too few arguments")
else:
cfg = Config(
file_=options.config,
ini_section=options.name,
cmd_opts=options,
)
self.run_cmd(cfg, options)
def main(
argv: Optional[Sequence[str]] = None,
prog: Optional[str] = None,
**kwargs: Any,
) -> None:
"""The console runner function for Alembic."""
CommandLine(prog=prog).main(argv=argv)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,5 @@
from .runtime.environment import EnvironmentContext
# create proxy functions for
# each method on the EnvironmentContext class.
EnvironmentContext.create_module_class_proxy(globals(), locals())

View File

@@ -0,0 +1,853 @@
# ### this file stubs are generated by tools/write_pyi.py - do not edit ###
# ### imports are manually managed
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Collection
from typing import ContextManager
from typing import Dict
from typing import Iterable
from typing import List
from typing import Literal
from typing import Mapping
from typing import MutableMapping
from typing import Optional
from typing import overload
from typing import Sequence
from typing import TextIO
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
if TYPE_CHECKING:
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import Executable
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import FetchedValue
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import SchemaItem
from sqlalchemy.sql.type_api import TypeEngine
from .autogenerate.api import AutogenContext
from .config import Config
from .operations.ops import MigrationScript
from .runtime.migration import _ProxyTransaction
from .runtime.migration import MigrationContext
from .runtime.migration import MigrationInfo
from .script import ScriptDirectory
### end imports ###
def begin_transaction() -> Union[_ProxyTransaction, ContextManager[None]]:
"""Return a context manager that will
enclose an operation within a "transaction",
as defined by the environment's offline
and transactional DDL settings.
e.g.::
with context.begin_transaction():
context.run_migrations()
:meth:`.begin_transaction` is intended to
"do the right thing" regardless of
calling context:
* If :meth:`.is_transactional_ddl` is ``False``,
returns a "do nothing" context manager
which otherwise produces no transactional
state or directives.
* If :meth:`.is_offline_mode` is ``True``,
returns a context manager that will
invoke the :meth:`.DefaultImpl.emit_begin`
and :meth:`.DefaultImpl.emit_commit`
methods, which will produce the string
directives ``BEGIN`` and ``COMMIT`` on
the output stream, as rendered by the
target backend (e.g. SQL Server would
emit ``BEGIN TRANSACTION``).
* Otherwise, calls :meth:`sqlalchemy.engine.Connection.begin`
on the current online connection, which
returns a :class:`sqlalchemy.engine.Transaction`
object. This object demarcates a real
transaction and is itself a context manager,
which will roll back if an exception
is raised.
Note that a custom ``env.py`` script which
has more specific transactional needs can of course
manipulate the :class:`~sqlalchemy.engine.Connection`
directly to produce transactional state in "online"
mode.
"""
config: Config
def configure(
connection: Optional[Connection] = None,
url: Union[str, URL, None] = None,
dialect_name: Optional[str] = None,
dialect_opts: Optional[Dict[str, Any]] = None,
transactional_ddl: Optional[bool] = None,
transaction_per_migration: bool = False,
output_buffer: Optional[TextIO] = None,
starting_rev: Optional[str] = None,
tag: Optional[str] = None,
template_args: Optional[Dict[str, Any]] = None,
render_as_batch: bool = False,
target_metadata: Union[MetaData, Sequence[MetaData], None] = None,
include_name: Optional[
Callable[
[
Optional[str],
Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
MutableMapping[
Literal[
"schema_name",
"table_name",
"schema_qualified_table_name",
],
Optional[str],
],
],
bool,
]
] = None,
include_object: Optional[
Callable[
[
SchemaItem,
Optional[str],
Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
bool,
Optional[SchemaItem],
],
bool,
]
] = None,
include_schemas: bool = False,
process_revision_directives: Optional[
Callable[
[
MigrationContext,
Union[str, Iterable[Optional[str]], Iterable[str]],
List[MigrationScript],
],
None,
]
] = None,
compare_type: Union[
bool,
Callable[
[
MigrationContext,
Column[Any],
Column[Any],
TypeEngine[Any],
TypeEngine[Any],
],
Optional[bool],
],
] = True,
compare_server_default: Union[
bool,
Callable[
[
MigrationContext,
Column[Any],
Column[Any],
Optional[str],
Optional[FetchedValue],
Optional[str],
],
Optional[bool],
],
] = False,
render_item: Optional[
Callable[[str, Any, AutogenContext], Union[str, Literal[False]]]
] = None,
literal_binds: bool = False,
upgrade_token: str = "upgrades",
downgrade_token: str = "downgrades",
alembic_module_prefix: str = "op.",
sqlalchemy_module_prefix: str = "sa.",
user_module_prefix: Optional[str] = None,
on_version_apply: Optional[
Callable[
[
MigrationContext,
MigrationInfo,
Collection[Any],
Mapping[str, Any],
],
None,
]
] = None,
**kw: Any,
) -> None:
"""Configure a :class:`.MigrationContext` within this
:class:`.EnvironmentContext` which will provide database
connectivity and other configuration to a series of
migration scripts.
Many methods on :class:`.EnvironmentContext` require that
this method has been called in order to function, as they
ultimately need to have database access or at least access
to the dialect in use. Those which do are documented as such.
The important thing needed by :meth:`.configure` is a
means to determine what kind of database dialect is in use.
An actual connection to that database is needed only if
the :class:`.MigrationContext` is to be used in
"online" mode.
If the :meth:`.is_offline_mode` function returns ``True``,
then no connection is needed here. Otherwise, the
``connection`` parameter should be present as an
instance of :class:`sqlalchemy.engine.Connection`.
This function is typically called from the ``env.py``
script within a migration environment. It can be called
multiple times for an invocation. The most recent
:class:`~sqlalchemy.engine.Connection`
for which it was called is the one that will be operated upon
by the next call to :meth:`.run_migrations`.
General parameters:
:param connection: a :class:`~sqlalchemy.engine.Connection`
to use
for SQL execution in "online" mode. When present, is also
used to determine the type of dialect in use.
:param url: a string database url, or a
:class:`sqlalchemy.engine.url.URL` object.
The type of dialect to be used will be derived from this if
``connection`` is not passed.
:param dialect_name: string name of a dialect, such as
"postgresql", "mssql", etc.
The type of dialect to be used will be derived from this if
``connection`` and ``url`` are not passed.
:param dialect_opts: dictionary of options to be passed to dialect
constructor.
:param transactional_ddl: Force the usage of "transactional"
DDL on or off;
this otherwise defaults to whether or not the dialect in
use supports it.
:param transaction_per_migration: if True, nest each migration script
in a transaction rather than the full series of migrations to
run.
:param output_buffer: a file-like object that will be used
for textual output
when the ``--sql`` option is used to generate SQL scripts.
Defaults to
``sys.stdout`` if not passed here and also not present on
the :class:`.Config`
object. The value here overrides that of the :class:`.Config`
object.
:param output_encoding: when using ``--sql`` to generate SQL
scripts, apply this encoding to the string output.
:param literal_binds: when using ``--sql`` to generate SQL
scripts, pass through the ``literal_binds`` flag to the compiler
so that any literal values that would ordinarily be bound
parameters are converted to plain strings.
.. warning:: Dialects can typically only handle simple datatypes
like strings and numbers for auto-literal generation. Datatypes
like dates, intervals, and others may still require manual
formatting, typically using :meth:`.Operations.inline_literal`.
.. note:: the ``literal_binds`` flag is ignored on SQLAlchemy
versions prior to 0.8 where this feature is not supported.
.. seealso::
:meth:`.Operations.inline_literal`
:param starting_rev: Override the "starting revision" argument
when using ``--sql`` mode.
:param tag: a string tag for usage by custom ``env.py`` scripts.
Set via the ``--tag`` option, can be overridden here.
:param template_args: dictionary of template arguments which
will be added to the template argument environment when
running the "revision" command. Note that the script environment
is only run within the "revision" command if the --autogenerate
option is used, or if the option "revision_environment=true"
is present in the alembic.ini file.
:param version_table: The name of the Alembic version table.
The default is ``'alembic_version'``.
:param version_table_schema: Optional schema to place version
table within.
:param version_table_pk: boolean, whether the Alembic version table
should use a primary key constraint for the "value" column; this
only takes effect when the table is first created.
Defaults to True; setting to False should not be necessary and is
here for backwards compatibility reasons.
:param on_version_apply: a callable or collection of callables to be
run for each migration step.
The callables will be run in the order they are given, once for
each migration step, after the respective operation has been
applied but before its transaction is finalized.
Each callable accepts no positional arguments and the following
keyword arguments:
* ``ctx``: the :class:`.MigrationContext` running the migration,
* ``step``: a :class:`.MigrationInfo` representing the
step currently being applied,
* ``heads``: a collection of version strings representing the
current heads,
* ``run_args``: the ``**kwargs`` passed to :meth:`.run_migrations`.
Parameters specific to the autogenerate feature, when
``alembic revision`` is run with the ``--autogenerate`` feature:
:param target_metadata: a :class:`sqlalchemy.schema.MetaData`
object, or a sequence of :class:`~sqlalchemy.schema.MetaData`
objects, that will be consulted during autogeneration.
The tables present in each :class:`~sqlalchemy.schema.MetaData`
will be compared against
what is locally available on the target
:class:`~sqlalchemy.engine.Connection`
to produce candidate upgrade/downgrade operations.
:param compare_type: Indicates type comparison behavior during
an autogenerate
operation. Defaults to ``True`` turning on type comparison, which
has good accuracy on most backends. See :ref:`compare_types`
for an example as well as information on other type
comparison options. Set to ``False`` which disables type
comparison. A callable can also be passed to provide custom type
comparison, see :ref:`compare_types` for additional details.
.. versionchanged:: 1.12.0 The default value of
:paramref:`.EnvironmentContext.configure.compare_type` has been
changed to ``True``.
.. seealso::
:ref:`compare_types`
:paramref:`.EnvironmentContext.configure.compare_server_default`
:param compare_server_default: Indicates server default comparison
behavior during
an autogenerate operation. Defaults to ``False`` which disables
server default
comparison. Set to ``True`` to turn on server default comparison,
which has
varied accuracy depending on backend.
To customize server default comparison behavior, a callable may
be specified
which can filter server default comparisons during an
autogenerate operation.
defaults during an autogenerate operation. The format of this
callable is::
def my_compare_server_default(context, inspected_column,
metadata_column, inspected_default, metadata_default,
rendered_metadata_default):
# return True if the defaults are different,
# False if not, or None to allow the default implementation
# to compare these defaults
return None
context.configure(
# ...
compare_server_default = my_compare_server_default
)
``inspected_column`` is a dictionary structure as returned by
:meth:`sqlalchemy.engine.reflection.Inspector.get_columns`, whereas
``metadata_column`` is a :class:`sqlalchemy.schema.Column` from
the local model environment.
A return value of ``None`` indicates to allow default server default
comparison
to proceed. Note that some backends such as Postgresql actually
execute
the two defaults on the database side to compare for equivalence.
.. seealso::
:paramref:`.EnvironmentContext.configure.compare_type`
:param include_name: A callable function which is given
the chance to return ``True`` or ``False`` for any database reflected
object based on its name, including database schema names when
the :paramref:`.EnvironmentContext.configure.include_schemas` flag
is set to ``True``.
The function accepts the following positional arguments:
* ``name``: the name of the object, such as schema name or table name.
Will be ``None`` when indicating the default schema name of the
database connection.
* ``type``: a string describing the type of object; currently
``"schema"``, ``"table"``, ``"column"``, ``"index"``,
``"unique_constraint"``, or ``"foreign_key_constraint"``
* ``parent_names``: a dictionary of "parent" object names, that are
relative to the name being given. Keys in this dictionary may
include: ``"schema_name"``, ``"table_name"`` or
``"schema_qualified_table_name"``.
E.g.::
def include_name(name, type_, parent_names):
if type_ == "schema":
return name in ["schema_one", "schema_two"]
else:
return True
context.configure(
# ...
include_schemas = True,
include_name = include_name
)
.. seealso::
:ref:`autogenerate_include_hooks`
:paramref:`.EnvironmentContext.configure.include_object`
:paramref:`.EnvironmentContext.configure.include_schemas`
:param include_object: A callable function which is given
the chance to return ``True`` or ``False`` for any object,
indicating if the given object should be considered in the
autogenerate sweep.
The function accepts the following positional arguments:
* ``object``: a :class:`~sqlalchemy.schema.SchemaItem` object such
as a :class:`~sqlalchemy.schema.Table`,
:class:`~sqlalchemy.schema.Column`,
:class:`~sqlalchemy.schema.Index`
:class:`~sqlalchemy.schema.UniqueConstraint`,
or :class:`~sqlalchemy.schema.ForeignKeyConstraint` object
* ``name``: the name of the object. This is typically available
via ``object.name``.
* ``type``: a string describing the type of object; currently
``"table"``, ``"column"``, ``"index"``, ``"unique_constraint"``,
or ``"foreign_key_constraint"``
* ``reflected``: ``True`` if the given object was produced based on
table reflection, ``False`` if it's from a local :class:`.MetaData`
object.
* ``compare_to``: the object being compared against, if available,
else ``None``.
E.g.::
def include_object(object, name, type_, reflected, compare_to):
if (type_ == "column" and
not reflected and
object.info.get("skip_autogenerate", False)):
return False
else:
return True
context.configure(
# ...
include_object = include_object
)
For the use case of omitting specific schemas from a target database
when :paramref:`.EnvironmentContext.configure.include_schemas` is
set to ``True``, the :attr:`~sqlalchemy.schema.Table.schema`
attribute can be checked for each :class:`~sqlalchemy.schema.Table`
object passed to the hook, however it is much more efficient
to filter on schemas before reflection of objects takes place
using the :paramref:`.EnvironmentContext.configure.include_name`
hook.
.. seealso::
:ref:`autogenerate_include_hooks`
:paramref:`.EnvironmentContext.configure.include_name`
:paramref:`.EnvironmentContext.configure.include_schemas`
:param render_as_batch: if True, commands which alter elements
within a table will be placed under a ``with batch_alter_table():``
directive, so that batch migrations will take place.
.. seealso::
:ref:`batch_migrations`
:param include_schemas: If True, autogenerate will scan across
all schemas located by the SQLAlchemy
:meth:`~sqlalchemy.engine.reflection.Inspector.get_schema_names`
method, and include all differences in tables found across all
those schemas. When using this option, you may want to also
use the :paramref:`.EnvironmentContext.configure.include_name`
parameter to specify a callable which
can filter the tables/schemas that get included.
.. seealso::
:ref:`autogenerate_include_hooks`
:paramref:`.EnvironmentContext.configure.include_name`
:paramref:`.EnvironmentContext.configure.include_object`
:param render_item: Callable that can be used to override how
any schema item, i.e. column, constraint, type,
etc., is rendered for autogenerate. The callable receives a
string describing the type of object, the object, and
the autogen context. If it returns False, the
default rendering method will be used. If it returns None,
the item will not be rendered in the context of a Table
construct, that is, can be used to skip columns or constraints
within op.create_table()::
def my_render_column(type_, col, autogen_context):
if type_ == "column" and isinstance(col, MySpecialCol):
return repr(col)
else:
return False
context.configure(
# ...
render_item = my_render_column
)
Available values for the type string include: ``"column"``,
``"primary_key"``, ``"foreign_key"``, ``"unique"``, ``"check"``,
``"type"``, ``"server_default"``.
.. seealso::
:ref:`autogen_render_types`
:param upgrade_token: When autogenerate completes, the text of the
candidate upgrade operations will be present in this template
variable when ``script.py.mako`` is rendered. Defaults to
``upgrades``.
:param downgrade_token: When autogenerate completes, the text of the
candidate downgrade operations will be present in this
template variable when ``script.py.mako`` is rendered. Defaults to
``downgrades``.
:param alembic_module_prefix: When autogenerate refers to Alembic
:mod:`alembic.operations` constructs, this prefix will be used
(i.e. ``op.create_table``) Defaults to "``op.``".
Can be ``None`` to indicate no prefix.
:param sqlalchemy_module_prefix: When autogenerate refers to
SQLAlchemy
:class:`~sqlalchemy.schema.Column` or type classes, this prefix
will be used
(i.e. ``sa.Column("somename", sa.Integer)``) Defaults to "``sa.``".
Can be ``None`` to indicate no prefix.
Note that when dialect-specific types are rendered, autogenerate
will render them using the dialect module name, i.e. ``mssql.BIT()``,
``postgresql.UUID()``.
:param user_module_prefix: When autogenerate refers to a SQLAlchemy
type (e.g. :class:`.TypeEngine`) where the module name is not
under the ``sqlalchemy`` namespace, this prefix will be used
within autogenerate. If left at its default of
``None``, the ``__module__`` attribute of the type is used to
render the import module. It's a good practice to set this
and to have all custom types be available from a fixed module space,
in order to future-proof migration files against reorganizations
in modules.
.. seealso::
:ref:`autogen_module_prefix`
:param process_revision_directives: a callable function that will
be passed a structure representing the end result of an autogenerate
or plain "revision" operation, which can be manipulated to affect
how the ``alembic revision`` command ultimately outputs new
revision scripts. The structure of the callable is::
def process_revision_directives(context, revision, directives):
pass
The ``directives`` parameter is a Python list containing
a single :class:`.MigrationScript` directive, which represents
the revision file to be generated. This list as well as its
contents may be freely modified to produce any set of commands.
The section :ref:`customizing_revision` shows an example of
doing this. The ``context`` parameter is the
:class:`.MigrationContext` in use,
and ``revision`` is a tuple of revision identifiers representing the
current revision of the database.
The callable is invoked at all times when the ``--autogenerate``
option is passed to ``alembic revision``. If ``--autogenerate``
is not passed, the callable is invoked only if the
``revision_environment`` variable is set to True in the Alembic
configuration, in which case the given ``directives`` collection
will contain empty :class:`.UpgradeOps` and :class:`.DowngradeOps`
collections for ``.upgrade_ops`` and ``.downgrade_ops``. The
``--autogenerate`` option itself can be inferred by inspecting
``context.config.cmd_opts.autogenerate``.
The callable function may optionally be an instance of
a :class:`.Rewriter` object. This is a helper object that
assists in the production of autogenerate-stream rewriter functions.
.. seealso::
:ref:`customizing_revision`
:ref:`autogen_rewriter`
:paramref:`.command.revision.process_revision_directives`
Parameters specific to individual backends:
:param mssql_batch_separator: The "batch separator" which will
be placed between each statement when generating offline SQL Server
migrations. Defaults to ``GO``. Note this is in addition to the
customary semicolon ``;`` at the end of each statement; SQL Server
considers the "batch separator" to denote the end of an
individual statement execution, and cannot group certain
dependent operations in one step.
:param oracle_batch_separator: The "batch separator" which will
be placed between each statement when generating offline
Oracle migrations. Defaults to ``/``. Oracle doesn't add a
semicolon between statements like most other backends.
"""
def execute(
sql: Union[Executable, str],
execution_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Execute the given SQL using the current change context.
The behavior of :meth:`.execute` is the same
as that of :meth:`.Operations.execute`. Please see that
function's documentation for full detail including
caveats and limitations.
This function requires that a :class:`.MigrationContext` has
first been made available via :meth:`.configure`.
"""
def get_bind() -> Connection:
"""Return the current 'bind'.
In "online" mode, this is the
:class:`sqlalchemy.engine.Connection` currently being used
to emit SQL to the database.
This function requires that a :class:`.MigrationContext`
has first been made available via :meth:`.configure`.
"""
def get_context() -> MigrationContext:
"""Return the current :class:`.MigrationContext` object.
If :meth:`.EnvironmentContext.configure` has not been
called yet, raises an exception.
"""
def get_head_revision() -> Union[str, Tuple[str, ...], None]:
"""Return the hex identifier of the 'head' script revision.
If the script directory has multiple heads, this
method raises a :class:`.CommandError`;
:meth:`.EnvironmentContext.get_head_revisions` should be preferred.
This function does not require that the :class:`.MigrationContext`
has been configured.
.. seealso:: :meth:`.EnvironmentContext.get_head_revisions`
"""
def get_head_revisions() -> Union[str, Tuple[str, ...], None]:
"""Return the hex identifier of the 'heads' script revision(s).
This returns a tuple containing the version number of all
heads in the script directory.
This function does not require that the :class:`.MigrationContext`
has been configured.
"""
def get_revision_argument() -> Union[str, Tuple[str, ...], None]:
"""Get the 'destination' revision argument.
This is typically the argument passed to the
``upgrade`` or ``downgrade`` command.
If it was specified as ``head``, the actual
version number is returned; if specified
as ``base``, ``None`` is returned.
This function does not require that the :class:`.MigrationContext`
has been configured.
"""
def get_starting_revision_argument() -> Union[str, Tuple[str, ...], None]:
"""Return the 'starting revision' argument,
if the revision was passed using ``start:end``.
This is only meaningful in "offline" mode.
Returns ``None`` if no value is available
or was configured.
This function does not require that the :class:`.MigrationContext`
has been configured.
"""
def get_tag_argument() -> Optional[str]:
"""Return the value passed for the ``--tag`` argument, if any.
The ``--tag`` argument is not used directly by Alembic,
but is available for custom ``env.py`` configurations that
wish to use it; particularly for offline generation scripts
that wish to generate tagged filenames.
This function does not require that the :class:`.MigrationContext`
has been configured.
.. seealso::
:meth:`.EnvironmentContext.get_x_argument` - a newer and more
open ended system of extending ``env.py`` scripts via the command
line.
"""
@overload
def get_x_argument(as_dictionary: Literal[False]) -> List[str]: ...
@overload
def get_x_argument(as_dictionary: Literal[True]) -> Dict[str, str]: ...
@overload
def get_x_argument(
as_dictionary: bool = ...,
) -> Union[List[str], Dict[str, str]]:
"""Return the value(s) passed for the ``-x`` argument, if any.
The ``-x`` argument is an open ended flag that allows any user-defined
value or values to be passed on the command line, then available
here for consumption by a custom ``env.py`` script.
The return value is a list, returned directly from the ``argparse``
structure. If ``as_dictionary=True`` is passed, the ``x`` arguments
are parsed using ``key=value`` format into a dictionary that is
then returned. If there is no ``=`` in the argument, value is an empty
string.
.. versionchanged:: 1.13.1 Support ``as_dictionary=True`` when
arguments are passed without the ``=`` symbol.
For example, to support passing a database URL on the command line,
the standard ``env.py`` script can be modified like this::
cmd_line_url = context.get_x_argument(
as_dictionary=True).get('dbname')
if cmd_line_url:
engine = create_engine(cmd_line_url)
else:
engine = engine_from_config(
config.get_section(config.config_ini_section),
prefix='sqlalchemy.',
poolclass=pool.NullPool)
This then takes effect by running the ``alembic`` script as::
alembic -x dbname=postgresql://user:pass@host/dbname upgrade head
This function does not require that the :class:`.MigrationContext`
has been configured.
.. seealso::
:meth:`.EnvironmentContext.get_tag_argument`
:attr:`.Config.cmd_opts`
"""
def is_offline_mode() -> bool:
"""Return True if the current migrations environment
is running in "offline mode".
This is ``True`` or ``False`` depending
on the ``--sql`` flag passed.
This function does not require that the :class:`.MigrationContext`
has been configured.
"""
def is_transactional_ddl() -> bool:
"""Return True if the context is configured to expect a
transactional DDL capable backend.
This defaults to the type of database in use, and
can be overridden by the ``transactional_ddl`` argument
to :meth:`.configure`
This function requires that a :class:`.MigrationContext`
has first been made available via :meth:`.configure`.
"""
def run_migrations(**kw: Any) -> None:
"""Run migrations as determined by the current command line
configuration
as well as versioning information present (or not) in the current
database connection (if one is present).
The function accepts optional ``**kw`` arguments. If these are
passed, they are sent directly to the ``upgrade()`` and
``downgrade()``
functions within each target revision file. By modifying the
``script.py.mako`` file so that the ``upgrade()`` and ``downgrade()``
functions accept arguments, parameters can be passed here so that
contextual information, usually information to identify a particular
database in use, can be passed from a custom ``env.py`` script
to the migration functions.
This function requires that a :class:`.MigrationContext` has
first been made available via :meth:`.configure`.
"""
script: ScriptDirectory
def static_output(text: str) -> None:
"""Emit text directly to the "offline" SQL stream.
Typically this is for emitting comments that
start with --. The statement is not treated
as a SQL execution, no ; or batch separator
is added, etc.
"""

View File

@@ -0,0 +1,6 @@
from . import mssql
from . import mysql
from . import oracle
from . import postgresql
from . import sqlite
from .impl import DefaultImpl as DefaultImpl

View File

@@ -0,0 +1,325 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import UniqueConstraint
from typing_extensions import TypeGuard
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from alembic.autogenerate.api import AutogenContext
from alembic.ddl.impl import DefaultImpl
CompareConstraintType = Union[Constraint, Index]
_C = TypeVar("_C", bound=CompareConstraintType)
_clsreg: Dict[str, Type[_constraint_sig]] = {}
class ComparisonResult(NamedTuple):
status: Literal["equal", "different", "skip"]
message: str
@property
def is_equal(self) -> bool:
return self.status == "equal"
@property
def is_different(self) -> bool:
return self.status == "different"
@property
def is_skip(self) -> bool:
return self.status == "skip"
@classmethod
def Equal(cls) -> ComparisonResult:
"""the constraints are equal."""
return cls("equal", "The two constraints are equal")
@classmethod
def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
"""the constraints are different for the provided reason(s)."""
return cls("different", ", ".join(util.to_list(reason)))
@classmethod
def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
"""the constraint cannot be compared for the provided reason(s).
The message is logged, but the constraints will be otherwise
considered equal, meaning that no migration command will be
generated.
"""
return cls("skip", ", ".join(util.to_list(reason)))
class _constraint_sig(Generic[_C]):
const: _C
_sig: Tuple[Any, ...]
name: Optional[sqla_compat._ConstraintNameDefined]
impl: DefaultImpl
_is_index: ClassVar[bool] = False
_is_fk: ClassVar[bool] = False
_is_uq: ClassVar[bool] = False
_is_metadata: bool
def __init_subclass__(cls) -> None:
cls._register()
@classmethod
def _register(cls):
raise NotImplementedError()
def __init__(
self, is_metadata: bool, impl: DefaultImpl, const: _C
) -> None:
raise NotImplementedError()
def compare_to_reflected(
self, other: _constraint_sig[Any]
) -> ComparisonResult:
assert self.impl is other.impl
assert self._is_metadata
assert not other._is_metadata
return self._compare_to_reflected(other)
def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
raise NotImplementedError()
@classmethod
def from_constraint(
cls, is_metadata: bool, impl: DefaultImpl, constraint: _C
) -> _constraint_sig[_C]:
# these could be cached by constraint/impl, however, if the
# constraint is modified in place, then the sig is wrong. the mysql
# impl currently does this, and if we fixed that we can't be sure
# someone else might do it too, so play it safe.
sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint)
return sig
def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
return sqla_compat._get_constraint_final_name(
self.const, context.dialect
)
@util.memoized_property
def is_named(self):
return sqla_compat._constraint_is_named(self.const, self.impl.dialect)
@util.memoized_property
def unnamed(self) -> Tuple[Any, ...]:
return self._sig
@util.memoized_property
def unnamed_no_options(self) -> Tuple[Any, ...]:
raise NotImplementedError()
@util.memoized_property
def _full_sig(self) -> Tuple[Any, ...]:
return (self.name,) + self.unnamed
def __eq__(self, other) -> bool:
return self._full_sig == other._full_sig
def __ne__(self, other) -> bool:
return self._full_sig != other._full_sig
def __hash__(self) -> int:
return hash(self._full_sig)
class _uq_constraint_sig(_constraint_sig[UniqueConstraint]):
_is_uq = True
@classmethod
def _register(cls) -> None:
_clsreg["unique_constraint"] = cls
is_unique = True
def __init__(
self,
is_metadata: bool,
impl: DefaultImpl,
const: UniqueConstraint,
) -> None:
self.impl = impl
self.const = const
self.name = sqla_compat.constraint_name_or_none(const.name)
self._sig = tuple(sorted([col.name for col in const.columns]))
self._is_metadata = is_metadata
@property
def column_names(self) -> Tuple[str, ...]:
return tuple([col.name for col in self.const.columns])
def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
assert self._is_metadata
metadata_obj = self
conn_obj = other
assert is_uq_sig(conn_obj)
return self.impl.compare_unique_constraint(
metadata_obj.const, conn_obj.const
)
class _ix_constraint_sig(_constraint_sig[Index]):
_is_index = True
name: sqla_compat._ConstraintName
@classmethod
def _register(cls) -> None:
_clsreg["index"] = cls
def __init__(
self, is_metadata: bool, impl: DefaultImpl, const: Index
) -> None:
self.impl = impl
self.const = const
self.name = const.name
self.is_unique = bool(const.unique)
self._is_metadata = is_metadata
def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
assert self._is_metadata
metadata_obj = self
conn_obj = other
assert is_index_sig(conn_obj)
return self.impl.compare_indexes(metadata_obj.const, conn_obj.const)
@util.memoized_property
def has_expressions(self):
return sqla_compat.is_expression_index(self.const)
@util.memoized_property
def column_names(self) -> Tuple[str, ...]:
return tuple([col.name for col in self.const.columns])
@util.memoized_property
def column_names_optional(self) -> Tuple[Optional[str], ...]:
return tuple(
[getattr(col, "name", None) for col in self.const.expressions]
)
@util.memoized_property
def is_named(self):
return True
@util.memoized_property
def unnamed(self):
return (self.is_unique,) + self.column_names_optional
class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
_is_fk = True
@classmethod
def _register(cls) -> None:
_clsreg["foreign_key_constraint"] = cls
def __init__(
self,
is_metadata: bool,
impl: DefaultImpl,
const: ForeignKeyConstraint,
) -> None:
self._is_metadata = is_metadata
self.impl = impl
self.const = const
self.name = sqla_compat.constraint_name_or_none(const.name)
(
self.source_schema,
self.source_table,
self.source_columns,
self.target_schema,
self.target_table,
self.target_columns,
onupdate,
ondelete,
deferrable,
initially,
) = sqla_compat._fk_spec(const)
self._sig: Tuple[Any, ...] = (
self.source_schema,
self.source_table,
tuple(self.source_columns),
self.target_schema,
self.target_table,
tuple(self.target_columns),
) + (
(None if onupdate.lower() == "no action" else onupdate.lower())
if onupdate
else None,
(None if ondelete.lower() == "no action" else ondelete.lower())
if ondelete
else None,
# convert initially + deferrable into one three-state value
"initially_deferrable"
if initially and initially.lower() == "deferred"
else "deferrable"
if deferrable
else "not deferrable",
)
@util.memoized_property
def unnamed_no_options(self):
return (
self.source_schema,
self.source_table,
tuple(self.source_columns),
self.target_schema,
self.target_table,
tuple(self.target_columns),
)
def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]:
return sig._is_index
def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
return sig._is_uq
def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
return sig._is_fk

View File

@@ -0,0 +1,335 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import functools
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import exc
from sqlalchemy import Integer
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import DDLElement
from sqlalchemy.sql.elements import quoted_name
from ..util.sqla_compat import _columns_for_constraint # noqa
from ..util.sqla_compat import _find_columns # noqa
from ..util.sqla_compat import _fk_spec # noqa
from ..util.sqla_compat import _is_type_bound # noqa
from ..util.sqla_compat import _table_for_constraint # noqa
if TYPE_CHECKING:
from typing import Any
from sqlalchemy.sql.compiler import Compiled
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import FetchedValue
from sqlalchemy.sql.type_api import TypeEngine
from .impl import DefaultImpl
from ..util.sqla_compat import Computed
from ..util.sqla_compat import Identity
_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]
class AlterTable(DDLElement):
"""Represent an ALTER TABLE statement.
Only the string name and optional schema name of the table
is required, not a full Table object.
"""
def __init__(
self,
table_name: str,
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
self.table_name = table_name
self.schema = schema
class RenameTable(AlterTable):
def __init__(
self,
old_table_name: str,
new_table_name: Union[quoted_name, str],
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
super().__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
class AlterColumn(AlterTable):
def __init__(
self,
name: str,
column_name: str,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_nullable: Optional[bool] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_comment: Optional[str] = None,
) -> None:
super().__init__(name, schema=schema)
self.column_name = column_name
self.existing_type = (
sqltypes.to_instance(existing_type)
if existing_type is not None
else None
)
self.existing_nullable = existing_nullable
self.existing_server_default = existing_server_default
self.existing_comment = existing_comment
class ColumnNullable(AlterColumn):
def __init__(
self, name: str, column_name: str, nullable: bool, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.nullable = nullable
class ColumnType(AlterColumn):
def __init__(
self, name: str, column_name: str, type_: TypeEngine, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
class ColumnName(AlterColumn):
def __init__(
self, name: str, column_name: str, newname: str, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.newname = newname
class ColumnDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: Optional[_ServerDefault],
**kw,
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
class ComputedColumnDefault(AlterColumn):
def __init__(
self, name: str, column_name: str, default: Optional[Computed], **kw
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
class IdentityColumnDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: Optional[Identity],
impl: DefaultImpl,
**kw,
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
self.impl = impl
class AddColumn(AlterTable):
def __init__(
self,
name: str,
column: Column[Any],
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
super().__init__(name, schema=schema)
self.column = column
class DropColumn(AlterTable):
def __init__(
self, name: str, column: Column[Any], schema: Optional[str] = None
) -> None:
super().__init__(name, schema=schema)
self.column = column
class ColumnComment(AlterColumn):
def __init__(
self, name: str, column_name: str, comment: Optional[str], **kw
) -> None:
super().__init__(name, column_name, **kw)
self.comment = comment
@compiles(RenameTable) # type: ignore[misc]
def visit_rename_table(
element: RenameTable, compiler: DDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, element.schema),
)
@compiles(AddColumn) # type: ignore[misc]
def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
)
@compiles(DropColumn) # type: ignore[misc]
def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
drop_column(compiler, element.column.name, **kw),
)
@compiles(ColumnNullable) # type: ignore[misc]
def visit_column_nullable(
element: ColumnNullable, compiler: DDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"DROP NOT NULL" if element.nullable else "SET NOT NULL",
)
@compiles(ColumnType) # type: ignore[misc]
def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"TYPE %s" % format_type(compiler, element.type_),
)
@compiles(ColumnName) # type: ignore[misc]
def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
return "%s RENAME %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnDefault) # type: ignore[misc]
def visit_column_default(
element: ColumnDefault, compiler: DDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
)
@compiles(ComputedColumnDefault) # type: ignore[misc]
def visit_computed_column(
element: ComputedColumnDefault, compiler: DDLCompiler, **kw
):
raise exc.CompileError(
'Adding or removing a "computed" construct, e.g. GENERATED '
"ALWAYS AS, to or from an existing column is not supported."
)
@compiles(IdentityColumnDefault) # type: ignore[misc]
def visit_identity_column(
element: IdentityColumnDefault, compiler: DDLCompiler, **kw
):
raise exc.CompileError(
'Adding, removing or modifying an "identity" construct, '
"e.g. GENERATED AS IDENTITY, to or from an existing "
"column is not supported in this dialect."
)
def quote_dotted(
name: Union[quoted_name, str], quote: functools.partial
) -> Union[quoted_name, str]:
"""quote the elements of a dotted name"""
if isinstance(name, quoted_name):
return quote(name)
result = ".".join([quote(x) for x in name.split(".")])
return result
def format_table_name(
compiler: Compiled,
name: Union[quoted_name, str],
schema: Optional[Union[quoted_name, str]],
) -> Union[quoted_name, str]:
quote = functools.partial(compiler.preparer.quote)
if schema:
return quote_dotted(schema, quote) + "." + quote(name)
else:
return quote(name)
def format_column_name(
compiler: DDLCompiler, name: Optional[Union[quoted_name, str]]
) -> Union[quoted_name, str]:
return compiler.preparer.quote(name) # type: ignore[arg-type]
def format_server_default(
compiler: DDLCompiler,
default: Optional[_ServerDefault],
) -> str:
return compiler.get_column_default_string(
Column("x", Integer, server_default=default)
)
def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str:
return compiler.dialect.type_compiler.process(type_)
def alter_table(
compiler: DDLCompiler,
name: str,
schema: Optional[str],
) -> str:
return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
def drop_column(compiler: DDLCompiler, name: str, **kw) -> str:
return "DROP COLUMN %s" % format_column_name(compiler, name)
def alter_column(compiler: DDLCompiler, name: str) -> str:
return "ALTER COLUMN %s" % format_column_name(compiler, name)
def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str:
text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
const = " ".join(
compiler.process(constraint) for constraint in column.constraints
)
if const:
text += " " + const
return text

View File

@@ -0,0 +1,844 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import logging
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import schema
from sqlalchemy import text
from . import _autogen
from . import base
from ._autogen import _constraint_sig as _constraint_sig
from ._autogen import ComparisonResult as ComparisonResult
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from typing import TextIO
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql import Executable
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.sql.selectable import TableClause
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
from ..autogenerate.api import AutogenContext
from ..operations.batch import ApplyBatchImpl
from ..operations.batch import BatchOperationsImpl
log = logging.getLogger(__name__)
class ImplMeta(type):
def __init__(
cls,
classname: str,
bases: Tuple[Type[DefaultImpl]],
dict_: Dict[str, Any],
):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
_impls[dict_["__dialect__"]] = cls # type: ignore[assignment]
return newtype
_impls: Dict[str, Type[DefaultImpl]] = {}
class DefaultImpl(metaclass=ImplMeta):
"""Provide the entrypoint for major migration operations,
including database-specific behavioral variances.
While individual SQL/DDL constructs already provide
for database-specific implementations, variances here
allow for entirely different sequences of operations
to take place for a particular migration, such as
SQL Server's special 'IDENTITY INSERT' step for
bulk inserts.
"""
__dialect__ = "default"
transactional_ddl = False
command_terminator = ";"
type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},)
type_arg_extract: Sequence[str] = ()
# These attributes are deprecated in SQLAlchemy via #10247. They need to
# be ignored to support older version that did not use dialect kwargs.
# They only apply to Oracle and are replaced by oracle_order,
# oracle_on_null
identity_attrs_ignore: Tuple[str, ...] = ("order", "on_null")
def __init__(
self,
dialect: Dialect,
connection: Optional[Connection],
as_sql: bool,
transactional_ddl: Optional[bool],
output_buffer: Optional[TextIO],
context_opts: Dict[str, Any],
) -> None:
self.dialect = dialect
self.connection = connection
self.as_sql = as_sql
self.literal_binds = context_opts.get("literal_binds", False)
self.output_buffer = output_buffer
self.memo: dict = {}
self.context_opts = context_opts
if transactional_ddl is not None:
self.transactional_ddl = transactional_ddl
if self.literal_binds:
if not self.as_sql:
raise util.CommandError(
"Can't use literal_binds setting without as_sql mode"
)
@classmethod
def get_by_dialect(cls, dialect: Dialect) -> Type[DefaultImpl]:
return _impls[dialect.name]
def static_output(self, text: str) -> None:
assert self.output_buffer is not None
self.output_buffer.write(text + "\n\n")
self.output_buffer.flush()
def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
Normally, only returns True on SQLite when operations other
than add_column are present.
"""
return False
def prep_table_for_batch(
self, batch_impl: ApplyBatchImpl, table: Table
) -> None:
"""perform any operations needed on a table before a new
one is created to replace it in batch mode.
the PG dialect uses this to drop constraints on the table
before the new one uses those same names.
"""
@property
def bind(self) -> Optional[Connection]:
return self.connection
def _exec(
self,
construct: Union[Executable, str],
execution_options: Optional[dict[str, Any]] = None,
multiparams: Sequence[dict] = (),
params: Dict[str, Any] = util.immutabledict(),
) -> Optional[CursorResult]:
if isinstance(construct, str):
construct = text(construct)
if self.as_sql:
if multiparams or params:
# TODO: coverage
raise Exception("Execution arguments not allowed with as_sql")
compile_kw: dict[str, Any]
if self.literal_binds and not isinstance(
construct, schema.DDLElement
):
compile_kw = dict(compile_kwargs={"literal_binds": True})
else:
compile_kw = {}
if TYPE_CHECKING:
assert isinstance(construct, ClauseElement)
compiled = construct.compile(dialect=self.dialect, **compile_kw)
self.static_output(
str(compiled).replace("\t", " ").strip()
+ self.command_terminator
)
return None
else:
conn = self.connection
assert conn is not None
if execution_options:
conn = conn.execution_options(**execution_options)
if params:
assert isinstance(multiparams, tuple)
multiparams += (params,)
return conn.execute(construct, multiparams)
def execute(
self,
sql: Union[Executable, str],
execution_options: Optional[dict[str, Any]] = None,
) -> None:
self._exec(sql, execution_options)
def alter_column(
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
existing_comment: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
**kw: Any,
) -> None:
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
"autoincrement and existing_autoincrement "
"only make sense for MySQL",
stacklevel=3,
)
if nullable is not None:
self._exec(
base.ColumnNullable(
table_name,
column_name,
nullable,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
if server_default is not False:
kw = {}
cls_: Type[
Union[
base.ComputedColumnDefault,
base.IdentityColumnDefault,
base.ColumnDefault,
]
]
if sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
cls_ = base.ComputedColumnDefault
elif sqla_compat._server_default_is_identity(
server_default, existing_server_default
):
cls_ = base.IdentityColumnDefault
kw["impl"] = self
else:
cls_ = base.ColumnDefault
self._exec(
cls_(
table_name,
column_name,
server_default, # type:ignore[arg-type]
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
**kw,
)
)
if type_ is not None:
self._exec(
base.ColumnType(
table_name,
column_name,
type_,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
if comment is not False:
self._exec(
base.ColumnComment(
table_name,
column_name,
comment,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
# do the new name last ;)
if name is not None:
self._exec(
base.ColumnName(
table_name,
column_name,
name,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
)
)
def add_column(
self,
table_name: str,
column: Column[Any],
schema: Optional[Union[str, quoted_name]] = None,
) -> None:
self._exec(base.AddColumn(table_name, column, schema=schema))
def drop_column(
self,
table_name: str,
column: Column[Any],
schema: Optional[str] = None,
**kw,
) -> None:
self._exec(base.DropColumn(table_name, column, schema=schema))
def add_constraint(self, const: Any) -> None:
if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
def drop_constraint(self, const: Constraint) -> None:
self._exec(schema.DropConstraint(const))
def rename_table(
self,
old_table_name: str,
new_table_name: Union[str, quoted_name],
schema: Optional[Union[str, quoted_name]] = None,
) -> None:
self._exec(
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
def create_table(self, table: Table) -> None:
table.dispatch.before_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.CreateTable(table))
table.dispatch.after_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
for index in table.indexes:
self._exec(schema.CreateIndex(index))
with_comment = (
self.dialect.supports_comments and not self.dialect.inline_comments
)
comment = table.comment
if comment and with_comment:
self.create_table_comment(table)
for column in table.columns:
comment = column.comment
if comment and with_comment:
self.create_column_comment(column)
def drop_table(self, table: Table) -> None:
table.dispatch.before_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.DropTable(table))
table.dispatch.after_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
def create_index(self, index: Index, **kw: Any) -> None:
self._exec(schema.CreateIndex(index, **kw))
def create_table_comment(self, table: Table) -> None:
self._exec(schema.SetTableComment(table))
def drop_table_comment(self, table: Table) -> None:
self._exec(schema.DropTableComment(table))
def create_column_comment(self, column: ColumnElement[Any]) -> None:
self._exec(schema.SetColumnComment(column))
def drop_index(self, index: Index, **kw: Any) -> None:
self._exec(schema.DropIndex(index, **kw))
def bulk_insert(
self,
table: Union[TableClause, Table],
rows: List[dict],
multiinsert: bool = True,
) -> None:
if not isinstance(rows, list):
raise TypeError("List expected")
elif rows and not isinstance(rows[0], dict):
raise TypeError("List of dictionaries expected")
if self.as_sql:
for row in rows:
self._exec(
sqla_compat._insert_inline(table).values(
**{
k: sqla_compat._literal_bindparam(
k, v, type_=table.c[k].type
)
if not isinstance(
v, sqla_compat._literal_bindparam
)
else v
for k, v in row.items()
}
)
)
else:
if rows:
if multiinsert:
self._exec(
sqla_compat._insert_inline(table), multiparams=rows
)
else:
for row in rows:
self._exec(
sqla_compat._insert_inline(table).values(**row)
)
def _tokenize_column_type(self, column: Column) -> Params:
definition: str
definition = self.dialect.type_compiler.process(column.type).lower()
# tokenize the SQLAlchemy-generated version of a type, so that
# the two can be compared.
#
# examples:
# NUMERIC(10, 5)
# TIMESTAMP WITH TIMEZONE
# INTEGER UNSIGNED
# INTEGER (10) UNSIGNED
# INTEGER(10) UNSIGNED
# varchar character set utf8
#
tokens: List[str] = re.findall(r"[\w\-_]+|\(.+?\)", definition)
term_tokens: List[str] = []
paren_term = None
for token in tokens:
if re.match(r"^\(.*\)$", token):
paren_term = token
else:
term_tokens.append(token)
params = Params(term_tokens[0], term_tokens[1:], [], {})
if paren_term:
term: str
for term in re.findall("[^(),]+", paren_term):
if "=" in term:
key, val = term.split("=")
params.kwargs[key.strip()] = val.strip()
else:
params.args.append(term.strip())
return params
def _column_types_match(
self, inspector_params: Params, metadata_params: Params
) -> bool:
if inspector_params.token0 == metadata_params.token0:
return True
synonyms = [{t.lower() for t in batch} for batch in self.type_synonyms]
inspector_all_terms = " ".join(
[inspector_params.token0] + inspector_params.tokens
)
metadata_all_terms = " ".join(
[metadata_params.token0] + metadata_params.tokens
)
for batch in synonyms:
if {inspector_all_terms, metadata_all_terms}.issubset(batch) or {
inspector_params.token0,
metadata_params.token0,
}.issubset(batch):
return True
return False
def _column_args_match(
self, inspected_params: Params, meta_params: Params
) -> bool:
"""We want to compare column parameters. However, we only want
to compare parameters that are set. If they both have `collation`,
we want to make sure they are the same. However, if only one
specifies it, dont flag it for being less specific
"""
if (
len(meta_params.tokens) == len(inspected_params.tokens)
and meta_params.tokens != inspected_params.tokens
):
return False
if (
len(meta_params.args) == len(inspected_params.args)
and meta_params.args != inspected_params.args
):
return False
insp = " ".join(inspected_params.tokens).lower()
meta = " ".join(meta_params.tokens).lower()
for reg in self.type_arg_extract:
mi = re.search(reg, insp)
mm = re.search(reg, meta)
if mi and mm and mi.group(1) != mm.group(1):
return False
return True
def compare_type(
self, inspector_column: Column[Any], metadata_column: Column
) -> bool:
"""Returns True if there ARE differences between the types of the two
columns. Takes impl.type_synonyms into account between retrospected
and metadata types
"""
inspector_params = self._tokenize_column_type(inspector_column)
metadata_params = self._tokenize_column_type(metadata_column)
if not self._column_types_match(inspector_params, metadata_params):
return True
if not self._column_args_match(inspector_params, metadata_params):
return True
return False
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
return rendered_inspector_default != rendered_metadata_default
def correct_for_autogen_constraints(
self,
conn_uniques: Set[UniqueConstraint],
conn_indexes: Set[Index],
metadata_unique_constraints: Set[UniqueConstraint],
metadata_indexes: Set[Index],
) -> None:
pass
def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
if existing.type._type_affinity is not new_type._type_affinity:
existing_transfer["expr"] = cast(
existing_transfer["expr"], new_type
)
def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw: Any
) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
"""
compile_kw = {"literal_binds": True, "include_table": False}
return str(
expr.compile(dialect=self.dialect, compile_kwargs=compile_kw)
)
def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable:
return self.autogen_column_reflect
def correct_for_autogen_foreignkeys(
self,
conn_fks: Set[ForeignKeyConstraint],
metadata_fks: Set[ForeignKeyConstraint],
) -> None:
pass
def autogen_column_reflect(self, inspector, table, column_info):
"""A hook that is attached to the 'column_reflect' event for when
a Table is reflected from the database during the autogenerate
process.
Dialects can elect to modify the information gathered here.
"""
def start_migrations(self) -> None:
"""A hook called when :meth:`.EnvironmentContext.run_migrations`
is called.
Implementations can set up per-migration-run state here.
"""
def emit_begin(self) -> None:
"""Emit the string ``BEGIN``, or the backend-specific
equivalent, on the current connection context.
This is used in offline mode and typically
via :meth:`.EnvironmentContext.begin_transaction`.
"""
self.static_output("BEGIN" + self.command_terminator)
def emit_commit(self) -> None:
"""Emit the string ``COMMIT``, or the backend-specific
equivalent, on the current connection context.
This is used in offline mode and typically
via :meth:`.EnvironmentContext.begin_transaction`.
"""
self.static_output("COMMIT" + self.command_terminator)
def render_type(
self, type_obj: TypeEngine, autogen_context: AutogenContext
) -> Union[str, Literal[False]]:
return False
def _compare_identity_default(self, metadata_identity, inspector_identity):
# ignored contains the attributes that were not considered
# because assumed to their default values in the db.
diff, ignored = _compare_identity_options(
metadata_identity,
inspector_identity,
sqla_compat.Identity(),
skip={"always"},
)
meta_always = getattr(metadata_identity, "always", None)
inspector_always = getattr(inspector_identity, "always", None)
# None and False are the same in this comparison
if bool(meta_always) != bool(inspector_always):
diff.add("always")
diff.difference_update(self.identity_attrs_ignore)
# returns 3 values:
return (
# different identity attributes
diff,
# ignored identity attributes
ignored,
# if the two identity should be considered different
bool(diff) or bool(metadata_identity) != bool(inspector_identity),
)
def _compare_index_unique(
self, metadata_index: Index, reflected_index: Index
) -> Optional[str]:
conn_unique = bool(reflected_index.unique)
meta_unique = bool(metadata_index.unique)
if conn_unique != meta_unique:
return f"unique={conn_unique} to unique={meta_unique}"
else:
return None
def _create_metadata_constraint_sig(
self, constraint: _autogen._C, **opts: Any
) -> _constraint_sig[_autogen._C]:
return _constraint_sig.from_constraint(True, self, constraint, **opts)
def _create_reflected_constraint_sig(
self, constraint: _autogen._C, **opts: Any
) -> _constraint_sig[_autogen._C]:
return _constraint_sig.from_constraint(False, self, constraint, **opts)
def compare_indexes(
self,
metadata_index: Index,
reflected_index: Index,
) -> ComparisonResult:
"""Compare two indexes by comparing the signature generated by
``create_index_sig``.
This method returns a ``ComparisonResult``.
"""
msg: List[str] = []
unique_msg = self._compare_index_unique(
metadata_index, reflected_index
)
if unique_msg:
msg.append(unique_msg)
m_sig = self._create_metadata_constraint_sig(metadata_index)
r_sig = self._create_reflected_constraint_sig(reflected_index)
assert _autogen.is_index_sig(m_sig)
assert _autogen.is_index_sig(r_sig)
# The assumption is that the index have no expression
for sig in m_sig, r_sig:
if sig.has_expressions:
log.warning(
"Generating approximate signature for index %s. "
"The dialect "
"implementation should either skip expression indexes "
"or provide a custom implementation.",
sig.const,
)
if m_sig.column_names != r_sig.column_names:
msg.append(
f"expression {r_sig.column_names} to {m_sig.column_names}"
)
if msg:
return ComparisonResult.Different(msg)
else:
return ComparisonResult.Equal()
def compare_unique_constraint(
self,
metadata_constraint: UniqueConstraint,
reflected_constraint: UniqueConstraint,
) -> ComparisonResult:
"""Compare two unique constraints by comparing the two signatures.
The arguments are two tuples that contain the unique constraint and
the signatures generated by ``create_unique_constraint_sig``.
This method returns a ``ComparisonResult``.
"""
metadata_tup = self._create_metadata_constraint_sig(
metadata_constraint
)
reflected_tup = self._create_reflected_constraint_sig(
reflected_constraint
)
meta_sig = metadata_tup.unnamed
conn_sig = reflected_tup.unnamed
if conn_sig != meta_sig:
return ComparisonResult.Different(
f"expression {conn_sig} to {meta_sig}"
)
else:
return ComparisonResult.Equal()
def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
conn_indexes_by_name = {c.name: c for c in conn_indexes}
for idx in list(metadata_indexes):
if idx.name in conn_indexes_by_name:
continue
iex = sqla_compat.is_expression_index(idx)
if iex:
util.warn(
"autogenerate skipping metadata-specified "
"expression-based index "
f"{idx.name!r}; dialect {self.__dialect__!r} under "
f"SQLAlchemy {sqla_compat.sqlalchemy_version} can't "
"reflect these indexes so they can't be compared"
)
metadata_indexes.discard(idx)
def adjust_reflected_dialect_options(
self, reflected_object: Dict[str, Any], kind: str
) -> Dict[str, Any]:
return reflected_object.get("dialect_options", {})
class Params(NamedTuple):
token0: str
tokens: List[str]
args: List[str]
kwargs: Dict[str, str]
def _compare_identity_options(
metadata_io: Union[schema.Identity, schema.Sequence, None],
inspector_io: Union[schema.Identity, schema.Sequence, None],
default_io: Union[schema.Identity, schema.Sequence],
skip: Set[str],
):
# this can be used for identity or sequence compare.
# default_io is an instance of IdentityOption with all attributes to the
# default value.
meta_d = sqla_compat._get_identity_options_dict(metadata_io)
insp_d = sqla_compat._get_identity_options_dict(inspector_io)
diff = set()
ignored_attr = set()
def check_dicts(
meta_dict: Mapping[str, Any],
insp_dict: Mapping[str, Any],
default_dict: Mapping[str, Any],
attrs: Iterable[str],
):
for attr in set(attrs).difference(skip):
meta_value = meta_dict.get(attr)
insp_value = insp_dict.get(attr)
if insp_value != meta_value:
default_value = default_dict.get(attr)
if meta_value == default_value:
ignored_attr.add(attr)
else:
diff.add(attr)
check_dicts(
meta_d,
insp_d,
sqla_compat._get_identity_options_dict(default_io),
set(meta_d).union(insp_d),
)
if sqla_compat.identity_has_dialect_kwargs:
# use only the dialect kwargs in inspector_io since metadata_io
# can have options for many backends
check_dicts(
getattr(metadata_io, "dialect_kwargs", {}),
getattr(inspector_io, "dialect_kwargs", {}),
default_io.dialect_kwargs, # type: ignore[union-attr]
getattr(inspector_io, "dialect_kwargs", {}),
)
return diff, ignored_attr

View File

@@ -0,0 +1,419 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import types as sqltypes
from sqlalchemy.schema import Column
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.elements import ClauseElement
from .base import AddColumn
from .base import alter_column
from .base import alter_table
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .base import format_table_name
from .base import format_type
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.mssql.base import MSDDLCompiler
from sqlalchemy.dialects.mssql.base import MSSQLCompiler
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import TableClause
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
class MSSQLImpl(DefaultImpl):
__dialect__ = "mssql"
transactional_ddl = True
batch_separator = "GO"
type_synonyms = DefaultImpl.type_synonyms + ({"VARCHAR", "NVARCHAR"},)
identity_attrs_ignore = DefaultImpl.identity_attrs_ignore + (
"minvalue",
"maxvalue",
"nominvalue",
"nomaxvalue",
"cycle",
"cache",
)
def __init__(self, *arg, **kw) -> None:
super().__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"mssql_batch_separator", self.batch_separator
)
def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
result = super()._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
def emit_begin(self) -> None:
self.static_output("BEGIN TRANSACTION" + self.command_terminator)
def emit_commit(self) -> None:
super().emit_commit()
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
**kw: Any,
) -> None:
if nullable is not None:
if type_ is not None:
# the NULL/NOT NULL alter will handle
# the type alteration
existing_type = type_
type_ = None
elif existing_type is None:
raise util.CommandError(
"MS-SQL ALTER COLUMN operations "
"with NULL or NOT NULL require the "
"existing_type or a new type_ be passed."
)
elif existing_nullable is not None and type_ is not None:
nullable = existing_nullable
# the NULL/NOT NULL alter will handle
# the type alteration
existing_type = type_
type_ = None
elif type_ is not None:
util.warn(
"MS-SQL ALTER COLUMN operations that specify type_= "
"should also specify a nullable= or "
"existing_nullable= argument to avoid implicit conversion "
"of NOT NULL columns to NULL."
)
used_default = False
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
used_default = True
kw["server_default"] = server_default
kw["existing_server_default"] = existing_server_default
super().alter_column(
table_name,
column_name,
nullable=nullable,
type_=type_,
schema=schema,
existing_type=existing_type,
existing_nullable=existing_nullable,
**kw,
)
if server_default is not False and used_default is False:
if existing_server_default is not False or server_default is None:
self._exec(
_ExecDropConstraint(
table_name,
column_name,
"sys.default_constraints",
schema,
)
)
if server_default is not None:
super().alter_column(
table_name,
column_name,
schema=schema,
server_default=server_default,
)
if name is not None:
super().alter_column(
table_name, column_name, schema=schema, name=name
)
def create_index(self, index: Index, **kw: Any) -> None:
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
mssql_include = index.kwargs.get("mssql_include", None) or ()
assert index.table is not None
for col in mssql_include:
if col not in index.table.c:
index.table.append_column(Column(col, sqltypes.NullType))
self._exec(CreateIndex(index, **kw))
def bulk_insert( # type:ignore[override]
self, table: Union[TableClause, Table], rows: List[dict], **kw: Any
) -> None:
if self.as_sql:
self._exec(
"SET IDENTITY_INSERT %s ON"
% self.dialect.identifier_preparer.format_table(table)
)
super().bulk_insert(table, rows, **kw)
self._exec(
"SET IDENTITY_INSERT %s OFF"
% self.dialect.identifier_preparer.format_table(table)
)
else:
super().bulk_insert(table, rows, **kw)
def drop_column(
self,
table_name: str,
column: Column[Any],
schema: Optional[str] = None,
**kw,
) -> None:
drop_default = kw.pop("mssql_drop_default", False)
if drop_default:
self._exec(
_ExecDropConstraint(
table_name, column, "sys.default_constraints", schema
)
)
drop_check = kw.pop("mssql_drop_check", False)
if drop_check:
self._exec(
_ExecDropConstraint(
table_name, column, "sys.check_constraints", schema
)
)
drop_fks = kw.pop("mssql_drop_foreign_key", False)
if drop_fks:
self._exec(_ExecDropFKConstraint(table_name, column, schema))
super().drop_column(table_name, column, schema=schema, **kw)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"[\(\) \"\']", "", rendered_metadata_default
)
if rendered_inspector_default is not None:
# SQL Server collapses whitespace and adds arbitrary parenthesis
# within expressions. our only option is collapse all of it
rendered_inspector_default = re.sub(
r"[\(\) \"\']", "", rendered_inspector_default
)
return rendered_inspector_default != rendered_metadata_default
def _compare_identity_default(self, metadata_identity, inspector_identity):
diff, ignored, is_alter = super()._compare_identity_default(
metadata_identity, inspector_identity
)
if (
metadata_identity is None
and inspector_identity is not None
and not diff
and inspector_identity.column is not None
and inspector_identity.column.primary_key
):
# mssql reflect primary keys with autoincrement as identity
# columns. if no different attributes are present ignore them
is_alter = False
return diff, ignored, is_alter
def adjust_reflected_dialect_options(
self, reflected_object: Dict[str, Any], kind: str
) -> Dict[str, Any]:
options: Dict[str, Any]
options = reflected_object.get("dialect_options", {}).copy()
if not options.get("mssql_include"):
options.pop("mssql_include", None)
if not options.get("mssql_clustered"):
options.pop("mssql_clustered", None)
return options
class _ExecDropConstraint(Executable, ClauseElement):
inherit_cache = False
def __init__(
self,
tname: str,
colname: Union[Column[Any], str],
type_: str,
schema: Optional[str],
) -> None:
self.tname = tname
self.colname = colname
self.type_ = type_
self.schema = schema
class _ExecDropFKConstraint(Executable, ClauseElement):
inherit_cache = False
def __init__(
self, tname: str, colname: Column[Any], schema: Optional[str]
) -> None:
self.tname = tname
self.colname = colname
self.schema = schema
@compiles(_ExecDropConstraint, "mssql")
def _exec_drop_col_constraint(
element: _ExecDropConstraint, compiler: MSSQLCompiler, **kw
) -> str:
schema, tname, colname, type_ = (
element.schema,
element.tname,
element.colname,
element.type_,
)
# from http://www.mssqltips.com/sqlservertip/1425/\
# working-with-default-constraints-in-sql-server/
return """declare @const_name varchar(256)
select @const_name = QUOTENAME([name]) from %(type)s
where parent_object_id = object_id('%(schema_dot)s%(tname)s')
and col_name(parent_object_id, parent_column_id) = '%(colname)s'
exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
"type": type_,
"tname": tname,
"colname": colname,
"tname_quoted": format_table_name(compiler, tname, schema),
"schema_dot": schema + "." if schema else "",
}
@compiles(_ExecDropFKConstraint, "mssql")
def _exec_drop_col_fk_constraint(
element: _ExecDropFKConstraint, compiler: MSSQLCompiler, **kw
) -> str:
schema, tname, colname = element.schema, element.tname, element.colname
return """declare @const_name varchar(256)
select @const_name = QUOTENAME([name]) from
sys.foreign_keys fk join sys.foreign_key_columns fkc
on fk.object_id=fkc.constraint_object_id
where fkc.parent_object_id = object_id('%(schema_dot)s%(tname)s')
and col_name(fkc.parent_object_id, fkc.parent_column_id) = '%(colname)s'
exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
"tname": tname,
"colname": colname,
"tname_quoted": format_table_name(compiler, tname, schema),
"schema_dot": schema + "." if schema else "",
}
@compiles(AddColumn, "mssql")
def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
mssql_add_column(compiler, element.column, **kw),
)
def mssql_add_column(
compiler: MSDDLCompiler, column: Column[Any], **kw
) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(ColumnNullable, "mssql")
def visit_column_nullable(
element: ColumnNullable, compiler: MSDDLCompiler, **kw
) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
format_type(compiler, element.existing_type), # type: ignore[arg-type]
"NULL" if element.nullable else "NOT NULL",
)
@compiles(ColumnDefault, "mssql")
def visit_column_default(
element: ColumnDefault, compiler: MSDDLCompiler, **kw
) -> str:
# TODO: there can also be a named constraint
# with ADD CONSTRAINT here
return "%s ADD DEFAULT %s FOR %s" % (
alter_table(compiler, element.table_name, element.schema),
format_server_default(compiler, element.default),
format_column_name(compiler, element.column_name),
)
@compiles(ColumnName, "mssql")
def visit_rename_column(
element: ColumnName, compiler: MSDDLCompiler, **kw
) -> str:
return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
format_table_name(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnType, "mssql")
def visit_column_type(
element: ColumnType, compiler: MSDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
format_type(compiler, element.type_),
)
@compiles(RenameTable, "mssql")
def visit_rename_table(
element: RenameTable, compiler: MSDDLCompiler, **kw
) -> str:
return "EXEC sp_rename '%s', %s" % (
format_table_name(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)

View File

@@ -0,0 +1,474 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import schema
from sqlalchemy import types as sqltypes
from .base import alter_table
from .base import AlterColumn
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
from sqlalchemy.sql.ddl import DropConstraint
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
class MySQLImpl(DefaultImpl):
__dialect__ = "mysql"
transactional_ddl = False
type_synonyms = DefaultImpl.type_synonyms + (
{"BOOL", "TINYINT"},
{"JSON", "LONGTEXT"},
)
type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
autoincrement: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
existing_comment: Optional[str] = None,
**kw: Any,
) -> None:
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
# modifying computed or identity columns is not supported
# the default will raise
super().alter_column(
table_name,
column_name,
nullable=nullable,
type_=type_,
schema=schema,
existing_type=existing_type,
existing_nullable=existing_nullable,
server_default=server_default,
existing_server_default=existing_server_default,
**kw,
)
if name is not None or self._is_mysql_allowed_functional_default(
type_ if type_ is not None else existing_type, server_default
):
self._exec(
MySQLChangeColumn(
table_name,
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
)
)
elif (
nullable is not None
or type_ is not None
or autoincrement is not None
or comment is not False
):
self._exec(
MySQLModifyColumn(
table_name,
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
)
)
elif server_default is not False:
self._exec(
MySQLAlterDefault(
table_name, column_name, server_default, schema=schema
)
)
def drop_constraint(
self,
const: Constraint,
) -> None:
if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
return
super().drop_constraint(const)
def _is_mysql_allowed_functional_default(
self,
type_: Optional[TypeEngine],
server_default: Union[_ServerDefault, Literal[False]],
) -> bool:
return (
type_ is not None
and type_._type_affinity is sqltypes.DateTime
and server_default is not None
)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
# partially a workaround for SQLAlchemy issue #3023; if the
# column were created without "NOT NULL", MySQL may have added
# an implicit default of '0' which we need to skip
# TODO: this is not really covered anymore ?
if (
metadata_column.type._type_affinity is sqltypes.Integer
and inspector_column.primary_key
and not inspector_column.autoincrement
and not rendered_metadata_default
and rendered_inspector_default == "'0'"
):
return False
elif (
rendered_inspector_default
and inspector_column.type._type_affinity is sqltypes.Integer
):
rendered_inspector_default = (
re.sub(r"^'|'$", "", rendered_inspector_default)
if rendered_inspector_default is not None
else None
)
return rendered_inspector_default != rendered_metadata_default
elif (
rendered_metadata_default
and metadata_column.type._type_affinity is sqltypes.String
):
metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default)
return rendered_inspector_default != f"'{metadata_default}'"
elif rendered_inspector_default and rendered_metadata_default:
# adjust for "function()" vs. "FUNCTION" as can occur particularly
# for the CURRENT_TIMESTAMP function on newer MariaDB versions
# SQLAlchemy MySQL dialect bundles ON UPDATE into the server
# default; adjust for this possibly being present.
onupdate_ins = re.match(
r"(.*) (on update.*?)(?:\(\))?$",
rendered_inspector_default.lower(),
)
onupdate_met = re.match(
r"(.*) (on update.*?)(?:\(\))?$",
rendered_metadata_default.lower(),
)
if onupdate_ins:
if not onupdate_met:
return True
elif onupdate_ins.group(2) != onupdate_met.group(2):
return True
rendered_inspector_default = onupdate_ins.group(1)
rendered_metadata_default = onupdate_met.group(1)
return re.sub(
r"(.*?)(?:\(\))?$", r"\1", rendered_inspector_default.lower()
) != re.sub(
r"(.*?)(?:\(\))?$", r"\1", rendered_metadata_default.lower()
)
else:
return rendered_inspector_default != rendered_metadata_default
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
# TODO: if SQLA 1.0, make use of "duplicates_index"
# metadata
removed = set()
for idx in list(conn_indexes):
if idx.unique:
continue
# MySQL puts implicit indexes on FK columns, even if
# composite and even if MyISAM, so can't check this too easily.
# the name of the index may be the column name or it may
# be the name of the FK constraint.
for col in idx.columns:
if idx.name == col.name:
conn_indexes.remove(idx)
removed.add(idx.name)
break
for fk in col.foreign_keys:
if fk.name == idx.name:
conn_indexes.remove(idx)
removed.add(idx.name)
break
if idx.name in removed:
break
# then remove indexes from the "metadata_indexes"
# that we've removed from reflected, otherwise they come out
# as adds (see #202)
for idx in list(metadata_indexes):
if idx.name in removed:
metadata_indexes.remove(idx)
def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
conn_fk_by_sig = {
self._create_reflected_constraint_sig(fk).unnamed_no_options: fk
for fk in conn_fks
}
metadata_fk_by_sig = {
self._create_metadata_constraint_sig(fk).unnamed_no_options: fk
for fk in metadata_fks
}
for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
mdfk = metadata_fk_by_sig[sig]
cnfk = conn_fk_by_sig[sig]
# MySQL considers RESTRICT to be the default and doesn't
# report on it. if the model has explicit RESTRICT and
# the conn FK has None, set it to RESTRICT
if (
mdfk.ondelete is not None
and mdfk.ondelete.lower() == "restrict"
and cnfk.ondelete is None
):
cnfk.ondelete = "RESTRICT"
if (
mdfk.onupdate is not None
and mdfk.onupdate.lower() == "restrict"
and cnfk.onupdate is None
):
cnfk.onupdate = "RESTRICT"
class MariaDBImpl(MySQLImpl):
__dialect__ = "mariadb"
class MySQLAlterDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: _ServerDefault,
schema: Optional[str] = None,
) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.default = default
class MySQLChangeColumn(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
schema: Optional[str] = None,
newname: Optional[str] = None,
type_: Optional[TypeEngine] = None,
nullable: Optional[bool] = None,
default: Optional[Union[_ServerDefault, Literal[False]]] = False,
autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.nullable = nullable
self.newname = newname
self.default = default
self.autoincrement = autoincrement
self.comment = comment
if type_ is None:
raise util.CommandError(
"All MySQL CHANGE/MODIFY COLUMN operations "
"require the existing type."
)
self.type_ = sqltypes.to_instance(type_)
class MySQLModifyColumn(MySQLChangeColumn):
pass
@compiles(ColumnNullable, "mysql", "mariadb")
@compiles(ColumnName, "mysql", "mariadb")
@compiles(ColumnDefault, "mysql", "mariadb")
@compiles(ColumnType, "mysql", "mariadb")
def _mysql_doesnt_support_individual(element, compiler, **kw):
raise NotImplementedError(
"Individual alter column constructs not supported by MySQL"
)
@compiles(MySQLAlterDefault, "mysql", "mariadb")
def _mysql_alter_default(
element: MySQLAlterDefault, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s ALTER COLUMN %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
)
@compiles(MySQLModifyColumn, "mysql", "mariadb")
def _mysql_modify_column(
element: MySQLModifyColumn, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s MODIFY %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
_mysql_colspec(
compiler,
nullable=element.nullable,
server_default=element.default,
type_=element.type_,
autoincrement=element.autoincrement,
comment=element.comment,
),
)
@compiles(MySQLChangeColumn, "mysql", "mariadb")
def _mysql_change_column(
element: MySQLChangeColumn, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s CHANGE %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
_mysql_colspec(
compiler,
nullable=element.nullable,
server_default=element.default,
type_=element.type_,
autoincrement=element.autoincrement,
comment=element.comment,
),
)
def _mysql_colspec(
compiler: MySQLDDLCompiler,
nullable: Optional[bool],
server_default: Optional[Union[_ServerDefault, Literal[False]]],
type_: TypeEngine,
autoincrement: Optional[bool],
comment: Optional[Union[str, Literal[False]]],
) -> str:
spec = "%s %s" % (
compiler.dialect.type_compiler.process(type_),
"NULL" if nullable else "NOT NULL",
)
if autoincrement:
spec += " AUTO_INCREMENT"
if server_default is not False and server_default is not None:
spec += " DEFAULT %s" % format_server_default(compiler, server_default)
if comment:
spec += " COMMENT %s" % compiler.sql_compiler.render_literal_value(
comment, sqltypes.String()
)
return spec
@compiles(schema.DropConstraint, "mysql", "mariadb")
def _mysql_drop_constraint(
element: DropConstraint, compiler: MySQLDDLCompiler, **kw
) -> str:
"""Redefine SQLAlchemy's drop constraint to
raise errors for invalid constraint type."""
constraint = element.element
if isinstance(
constraint,
(
schema.ForeignKeyConstraint,
schema.PrimaryKeyConstraint,
schema.UniqueConstraint,
),
):
assert not kw
return compiler.visit_drop_constraint(element)
elif isinstance(constraint, schema.CheckConstraint):
# note that SQLAlchemy as of 1.2 does not yet support
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
# here.
if _is_mariadb(compiler.dialect):
return "ALTER TABLE %s DROP CONSTRAINT %s" % (
compiler.preparer.format_table(constraint.table),
compiler.preparer.format_constraint(constraint),
)
else:
return "ALTER TABLE %s DROP CHECK %s" % (
compiler.preparer.format_table(constraint.table),
compiler.preparer.format_constraint(constraint),
)
else:
raise NotImplementedError(
"No generic 'DROP CONSTRAINT' in MySQL - "
"please specify constraint type"
)

View File

@@ -0,0 +1,200 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from sqlalchemy.sql import sqltypes
from .base import AddColumn
from .base import alter_table
from .base import ColumnComment
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.sql.schema import Column
class OracleImpl(DefaultImpl):
__dialect__ = "oracle"
transactional_ddl = False
batch_separator = "/"
command_terminator = ""
type_synonyms = DefaultImpl.type_synonyms + (
{"VARCHAR", "VARCHAR2"},
{"BIGINT", "INTEGER", "SMALLINT", "DECIMAL", "NUMERIC", "NUMBER"},
{"DOUBLE", "FLOAT", "DOUBLE_PRECISION"},
)
identity_attrs_ignore = ()
def __init__(self, *arg, **kw) -> None:
super().__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"oracle_batch_separator", self.batch_separator
)
def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
result = super()._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_metadata_default
)
rendered_metadata_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default
)
if rendered_inspector_default is not None:
rendered_inspector_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_inspector_default
)
rendered_inspector_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default
)
rendered_inspector_default = rendered_inspector_default.strip()
return rendered_inspector_default != rendered_metadata_default
def emit_begin(self) -> None:
self._exec("SET TRANSACTION READ WRITE")
def emit_commit(self) -> None:
self._exec("COMMIT")
@compiles(AddColumn, "oracle")
def visit_add_column(
element: AddColumn, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
)
@compiles(ColumnNullable, "oracle")
def visit_column_nullable(
element: ColumnNullable, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"NULL" if element.nullable else "NOT NULL",
)
@compiles(ColumnType, "oracle")
def visit_column_type(
element: ColumnType, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"%s" % format_type(compiler, element.type_),
)
@compiles(ColumnName, "oracle")
def visit_column_name(
element: ColumnName, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s RENAME COLUMN %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnDefault, "oracle")
def visit_column_default(
element: ColumnDefault, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DEFAULT NULL",
)
@compiles(ColumnComment, "oracle")
def visit_column_comment(
element: ColumnComment, compiler: OracleDDLCompiler, **kw
) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = compiler.sql_compiler.render_literal_value(
(element.comment if element.comment is not None else ""),
sqltypes.String(),
)
return ddl.format(
table_name=element.table_name,
column_name=element.column_name,
comment=comment,
)
@compiles(RenameTable, "oracle")
def visit_rename_table(
element: RenameTable, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
def alter_column(compiler: OracleDDLCompiler, name: str) -> str:
return "MODIFY %s" % format_column_name(compiler, name)
def add_column(compiler: OracleDDLCompiler, column: Column[Any], **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(IdentityColumnDefault, "oracle")
def visit_identity_column(
element: IdentityColumnDefault, compiler: OracleDDLCompiler, **kw
):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
)
if element.default is None:
# drop identity
text += "DROP IDENTITY"
return text
else:
text += compiler.visit_identity_column(element.default)
return text

View File

@@ -0,0 +1,848 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import logging
import re
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import Column
from sqlalchemy import literal_column
from sqlalchemy import Numeric
from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import FunctionElement
from sqlalchemy.types import NULLTYPE
from .base import alter_column
from .base import alter_table
from .base import AlterColumn
from .base import ColumnComment
from .base import format_column_name
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import ComparisonResult
from .impl import DefaultImpl
from .. import util
from ..autogenerate import render
from ..operations import ops
from ..operations import schemaobj
from ..operations.base import BatchOperations
from ..operations.base import Operations
from ..util import sqla_compat
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy import Index
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects.postgresql.array import ARRAY
from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
from sqlalchemy.dialects.postgresql.hstore import HSTORE
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.dialects.postgresql.json import JSONB
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
from ..autogenerate.api import AutogenContext
from ..autogenerate.render import _f_name
from ..runtime.migration import MigrationContext
log = logging.getLogger(__name__)
class PostgresqlImpl(DefaultImpl):
__dialect__ = "postgresql"
transactional_ddl = True
type_synonyms = DefaultImpl.type_synonyms + (
{"FLOAT", "DOUBLE PRECISION"},
)
def create_index(self, index: Index, **kw: Any) -> None:
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
postgresql_include = index.kwargs.get("postgresql_include", None) or ()
for col in postgresql_include:
if col not in index.table.c: # type: ignore[union-attr]
index.table.append_column( # type: ignore[union-attr]
Column(col, sqltypes.NullType)
)
self._exec(CreateIndex(index, **kw))
def prep_table_for_batch(self, batch_impl, table):
for constraint in table.constraints:
if (
constraint.name is not None
and constraint.name in batch_impl.named_constraints
):
self.drop_constraint(constraint)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
# don't do defaults for SERIAL columns
if (
metadata_column.primary_key
and metadata_column is metadata_column.table._autoincrement_column
):
return False
conn_col_default = rendered_inspector_default
defaults_equal = conn_col_default == rendered_metadata_default
if defaults_equal:
return False
if None in (
conn_col_default,
rendered_metadata_default,
metadata_column.server_default,
):
return not defaults_equal
metadata_default = metadata_column.server_default.arg
if isinstance(metadata_default, str):
if not isinstance(inspector_column.type, Numeric):
metadata_default = re.sub(r"^'|'$", "", metadata_default)
metadata_default = f"'{metadata_default}'"
metadata_default = literal_column(metadata_default)
# run a real compare against the server
conn = self.connection
assert conn is not None
return not conn.scalar(
sqla_compat._select(
literal_column(conn_col_default) == metadata_default
)
)
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
autoincrement: Optional[bool] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
**kw: Any,
) -> None:
using = kw.pop("postgresql_using", None)
if using is not None and type_ is None:
raise util.CommandError(
"postgresql_using must be used with the type_ parameter"
)
if type_ is not None:
self._exec(
PostgresqlColumnType(
table_name,
column_name,
type_,
schema=schema,
using=using,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
)
)
super().alter_column(
table_name,
column_name,
nullable=nullable,
server_default=server_default,
name=name,
schema=schema,
autoincrement=autoincrement,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_autoincrement=existing_autoincrement,
**kw,
)
def autogen_column_reflect(self, inspector, table, column_info):
if column_info.get("default") and isinstance(
column_info["type"], (INTEGER, BIGINT)
):
seq_match = re.match(
r"nextval\('(.+?)'::regclass\)", column_info["default"]
)
if seq_match:
info = sqla_compat._exec_on_inspector(
inspector,
text(
"select c.relname, a.attname "
"from pg_class as c join "
"pg_depend d on d.objid=c.oid and "
"d.classid='pg_class'::regclass and "
"d.refclassid='pg_class'::regclass "
"join pg_class t on t.oid=d.refobjid "
"join pg_attribute a on a.attrelid=t.oid and "
"a.attnum=d.refobjsubid "
"where c.relkind='S' and c.relname=:seqname"
),
seqname=seq_match.group(1),
).first()
if info:
seqname, colname = info
if colname == column_info["name"]:
log.info(
"Detected sequence named '%s' as "
"owned by integer column '%s(%s)', "
"assuming SERIAL and omitting",
seqname,
table.name,
colname,
)
# sequence, and the owner is this column,
# its a SERIAL - whack it!
del column_info["default"]
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
doubled_constraints = {
index
for index in conn_indexes
if index.info.get("duplicates_constraint")
}
for ix in doubled_constraints:
conn_indexes.remove(ix)
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)
# pg behavior regarding modifiers
# | # | compiled sql | returned sql | regexp. group is removed |
# | - | ---------------- | -----------------| ------------------------ |
# | 1 | nulls first | nulls first | - |
# | 2 | nulls last | | (?<! desc)( nulls last)$ |
# | 3 | asc | | ( asc)$ |
# | 4 | asc nulls first | nulls first | ( asc) nulls first$ |
# | 5 | asc nulls last | | ( asc nulls last)$ |
# | 6 | desc | desc | - |
# | 7 | desc nulls first | desc | desc( nulls first)$ |
# | 8 | desc nulls last | desc nulls last | - |
_default_modifiers_re = ( # order of case 2 and 5 matters
re.compile("( asc nulls last)$"), # case 5
re.compile("(?<! desc)( nulls last)$"), # case 2
re.compile("( asc)$"), # case 3
re.compile("( asc) nulls first$"), # case 4
re.compile(" desc( nulls first)$"), # case 7
)
def _cleanup_index_expr(self, index: Index, expr: str) -> str:
expr = expr.lower().replace('"', "").replace("'", "")
if index.table is not None:
# should not be needed, since include_table=False is in compile
expr = expr.replace(f"{index.table.name.lower()}.", "")
if "::" in expr:
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
# NOTE: when parsing the connection expression this cleanup could
# be skipped
for rs in self._default_modifiers_re:
if match := rs.search(expr):
start, end = match.span(1)
expr = expr[:start] + expr[end:]
break
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
# strip casts
cast_re = re.compile(r"cast\s*\(")
if cast_re.match(expr):
expr = cast_re.sub("", expr)
# remove the as type
expr = re.sub(r"as\s+[^)]+\)", "", expr)
# remove spaces
expr = expr.replace(" ", "")
return expr
def _dialect_options(
self, item: Union[Index, UniqueConstraint]
) -> Tuple[Any, ...]:
# only the positive case is returned by sqlalchemy reflection so
# None and False are threated the same
if item.dialect_kwargs.get("postgresql_nulls_not_distinct"):
return ("nulls_not_distinct",)
return ()
def compare_indexes(
self,
metadata_index: Index,
reflected_index: Index,
) -> ComparisonResult:
msg = []
unique_msg = self._compare_index_unique(
metadata_index, reflected_index
)
if unique_msg:
msg.append(unique_msg)
m_exprs = metadata_index.expressions
r_exprs = reflected_index.expressions
if len(m_exprs) != len(r_exprs):
msg.append(f"expression number {len(r_exprs)} to {len(m_exprs)}")
if msg:
# no point going further, return early
return ComparisonResult.Different(msg)
skip = []
for pos, (m_e, r_e) in enumerate(zip(m_exprs, r_exprs), 1):
m_compile = self._compile_element(m_e)
m_text = self._cleanup_index_expr(metadata_index, m_compile)
# print(f"META ORIG: {m_compile!r} CLEANUP: {m_text!r}")
r_compile = self._compile_element(r_e)
r_text = self._cleanup_index_expr(metadata_index, r_compile)
# print(f"CONN ORIG: {r_compile!r} CLEANUP: {r_text!r}")
if m_text == r_text:
continue # expressions these are equal
elif m_compile.strip().endswith("_ops") and (
" " in m_compile or ")" in m_compile # is an expression
):
skip.append(
f"expression #{pos} {m_compile!r} detected "
"as including operator clause."
)
util.warn(
f"Expression #{pos} {m_compile!r} in index "
f"{reflected_index.name!r} detected to include "
"an operator clause. Expression compare cannot proceed. "
"Please move the operator clause to the "
"``postgresql_ops`` dict to enable proper compare "
"of the index expressions: "
"https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#operator-classes", # noqa: E501
)
else:
msg.append(f"expression #{pos} {r_compile!r} to {m_compile!r}")
m_options = self._dialect_options(metadata_index)
r_options = self._dialect_options(reflected_index)
if m_options != r_options:
msg.extend(f"options {r_options} to {m_options}")
if msg:
return ComparisonResult.Different(msg)
elif skip:
# if there are other changes detected don't skip the index
return ComparisonResult.Skip(skip)
else:
return ComparisonResult.Equal()
def compare_unique_constraint(
self,
metadata_constraint: UniqueConstraint,
reflected_constraint: UniqueConstraint,
) -> ComparisonResult:
metadata_tup = self._create_metadata_constraint_sig(
metadata_constraint
)
reflected_tup = self._create_reflected_constraint_sig(
reflected_constraint
)
meta_sig = metadata_tup.unnamed
conn_sig = reflected_tup.unnamed
if conn_sig != meta_sig:
return ComparisonResult.Different(
f"expression {conn_sig} to {meta_sig}"
)
metadata_do = self._dialect_options(metadata_tup.const)
conn_do = self._dialect_options(reflected_tup.const)
if metadata_do != conn_do:
return ComparisonResult.Different(
f"expression {conn_do} to {metadata_do}"
)
return ComparisonResult.Equal()
def adjust_reflected_dialect_options(
self, reflected_options: Dict[str, Any], kind: str
) -> Dict[str, Any]:
options: Dict[str, Any]
options = reflected_options.get("dialect_options", {}).copy()
if not options.get("postgresql_include"):
options.pop("postgresql_include", None)
return options
def _compile_element(self, element: Union[ClauseElement, str]) -> str:
if isinstance(element, str):
return element
return element.compile(
dialect=self.dialect,
compile_kwargs={"literal_binds": True, "include_table": False},
).string
def render_ddl_sql_expr(
self,
expr: ClauseElement,
is_server_default: bool = False,
is_index: bool = False,
**kw: Any,
) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
"""
# apply self_group to index expressions;
# see https://github.com/sqlalchemy/sqlalchemy/blob/
# 82fa95cfce070fab401d020c6e6e4a6a96cc2578/
# lib/sqlalchemy/dialects/postgresql/base.py#L2261
if is_index and not isinstance(expr, ColumnClause):
expr = expr.self_group()
return super().render_ddl_sql_expr(
expr, is_server_default=is_server_default, is_index=is_index, **kw
)
def render_type(
self, type_: TypeEngine, autogen_context: AutogenContext
) -> Union[str, Literal[False]]:
mod = type(type_).__module__
if not mod.startswith("sqlalchemy.dialects.postgresql"):
return False
if hasattr(self, "_render_%s_type" % type_.__visit_name__):
meth = getattr(self, "_render_%s_type" % type_.__visit_name__)
return meth(type_, autogen_context)
return False
def _render_HSTORE_type(
self, type_: HSTORE, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
),
)
def _render_ARRAY_type(
self, type_: ARRAY, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "item_type", r"(.+?\()"
),
)
def _render_JSON_type(
self, type_: JSON, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
),
)
def _render_JSONB_type(
self, type_: JSONB, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
),
)
class PostgresqlColumnType(AlterColumn):
def __init__(
self, name: str, column_name: str, type_: TypeEngine, **kw
) -> None:
using = kw.pop("using", None)
super().__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
self.using = using
@compiles(RenameTable, "postgresql")
def visit_rename_table(
element: RenameTable, compiler: PGDDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
@compiles(PostgresqlColumnType, "postgresql")
def visit_column_type(
element: PostgresqlColumnType, compiler: PGDDLCompiler, **kw
) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"TYPE %s" % format_type(compiler, element.type_),
"USING %s" % element.using if element.using else "",
)
@compiles(ColumnComment, "postgresql")
def visit_column_comment(
element: ColumnComment, compiler: PGDDLCompiler, **kw
) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = (
compiler.sql_compiler.render_literal_value(
element.comment, sqltypes.String()
)
if element.comment is not None
else "NULL"
)
return ddl.format(
table_name=format_table_name(
compiler, element.table_name, element.schema
),
column_name=format_column_name(compiler, element.column_name),
comment=comment,
)
@compiles(IdentityColumnDefault, "postgresql")
def visit_identity_column(
element: IdentityColumnDefault, compiler: PGDDLCompiler, **kw
):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
)
if element.default is None:
# drop identity
text += "DROP IDENTITY"
return text
elif element.existing_server_default is None:
# add identity options
text += "ADD "
text += compiler.visit_identity_column(element.default)
return text
else:
# alter identity
diff, _, _ = element.impl._compare_identity_default(
element.default, element.existing_server_default
)
identity = element.default
for attr in sorted(diff):
if attr == "always":
text += "SET GENERATED %s " % (
"ALWAYS" if identity.always else "BY DEFAULT"
)
else:
text += "SET %s " % compiler.get_identity_options(
sqla_compat.Identity(**{attr: getattr(identity, attr)})
)
return text
@Operations.register_operation("create_exclude_constraint")
@BatchOperations.register_operation(
"create_exclude_constraint", "batch_create_exclude_constraint"
)
@ops.AddConstraintOp.register_add_constraint("exclude_constraint")
class CreateExcludeConstraintOp(ops.AddConstraintOp):
"""Represent a create exclude constraint operation."""
constraint_type = "exclude"
def __init__(
self,
constraint_name: sqla_compat._ConstraintName,
table_name: Union[str, quoted_name],
elements: Union[
Sequence[Tuple[str, str]],
Sequence[Tuple[ColumnClause[Any], str]],
],
where: Optional[Union[ColumnElement[bool], str]] = None,
schema: Optional[str] = None,
_orig_constraint: Optional[ExcludeConstraint] = None,
**kw,
) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.elements = elements
self.where = where
self.schema = schema
self._orig_constraint = _orig_constraint
self.kw = kw
@classmethod
def from_constraint( # type:ignore[override]
cls, constraint: ExcludeConstraint
) -> CreateExcludeConstraintOp:
constraint_table = sqla_compat._table_for_constraint(constraint)
return cls(
constraint.name,
constraint_table.name,
[ # type: ignore
(expr, op) for expr, name, op in constraint._render_exprs
],
where=cast("ColumnElement[bool] | None", constraint.where),
schema=constraint_table.schema,
_orig_constraint=constraint,
deferrable=constraint.deferrable,
initially=constraint.initially,
using=constraint.using,
)
def to_constraint(
self, migration_context: Optional[MigrationContext] = None
) -> ExcludeConstraint:
if self._orig_constraint is not None:
return self._orig_constraint
schema_obj = schemaobj.SchemaObjects(migration_context)
t = schema_obj.table(self.table_name, schema=self.schema)
excl = ExcludeConstraint(
*self.elements,
name=self.constraint_name,
where=self.where,
**self.kw,
)
for (
expr,
name,
oper,
) in excl._render_exprs:
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
@classmethod
def create_exclude_constraint(
cls,
operations: Operations,
constraint_name: str,
table_name: str,
*elements: Any,
**kw: Any,
) -> Optional[Table]:
"""Issue an alter to create an EXCLUDE constraint using the
current migration context.
.. note:: This method is Postgresql specific, and additionally
requires at least SQLAlchemy 1.0.
e.g.::
from alembic import op
op.create_exclude_constraint(
"user_excl",
"user",
("period", "&&"),
("group", "="),
where=("group != 'some group'"),
)
Note that the expressions work the same way as that of
the ``ExcludeConstraint`` object itself; if plain strings are
passed, quoting rules must be applied manually.
:param name: Name of the constraint.
:param table_name: String name of the source table.
:param elements: exclude conditions.
:param where: SQL expression or SQL string with optional WHERE
clause.
:param deferrable: optional bool. If set, emit DEFERRABLE or
NOT DEFERRABLE when issuing DDL for this constraint.
:param initially: optional string. If set, emit INITIALLY <value>
when issuing DDL for this constraint.
:param schema: Optional schema name to operate within.
"""
op = cls(constraint_name, table_name, elements, **kw)
return operations.invoke(op)
@classmethod
def batch_create_exclude_constraint(
cls,
operations: BatchOperations,
constraint_name: str,
*elements: Any,
**kw: Any,
) -> Optional[Table]:
"""Issue a "create exclude constraint" instruction using the
current batch migration context.
.. note:: This method is Postgresql specific, and additionally
requires at least SQLAlchemy 1.0.
.. seealso::
:meth:`.Operations.create_exclude_constraint`
"""
kw["schema"] = operations.impl.schema
op = cls(constraint_name, operations.impl.table_name, elements, **kw)
return operations.invoke(op)
@render.renderers.dispatch_for(CreateExcludeConstraintOp)
def _add_exclude_constraint(
autogen_context: AutogenContext, op: CreateExcludeConstraintOp
) -> str:
return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
@render._constraint_renderers.dispatch_for(ExcludeConstraint)
def _render_inline_exclude_constraint(
constraint: ExcludeConstraint,
autogen_context: AutogenContext,
namespace_metadata: MetaData,
) -> str:
rendered = render._user_defined_render(
"exclude", constraint, autogen_context
)
if rendered is not False:
return rendered
return _exclude_constraint(constraint, autogen_context, False)
def _postgresql_autogenerate_prefix(autogen_context: AutogenContext) -> str:
imports = autogen_context.imports
if imports is not None:
imports.add("from sqlalchemy.dialects import postgresql")
return "postgresql."
def _exclude_constraint(
constraint: ExcludeConstraint,
autogen_context: AutogenContext,
alter: bool,
) -> str:
opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = []
has_batch = autogen_context._has_batch
if constraint.deferrable:
opts.append(("deferrable", str(constraint.deferrable)))
if constraint.initially:
opts.append(("initially", str(constraint.initially)))
if constraint.using:
opts.append(("using", str(constraint.using)))
if not has_batch and alter and constraint.table.schema:
opts.append(("schema", render._ident(constraint.table.schema)))
if not alter and constraint.name:
opts.append(
("name", render._render_gen_name(autogen_context, constraint.name))
)
def do_expr_where_opts():
args = [
"(%s, %r)"
% (
_render_potential_column(
sqltext, # type:ignore[arg-type]
autogen_context,
),
opstring,
)
for sqltext, name, opstring in constraint._render_exprs
]
if constraint.where is not None:
args.append(
"where=%s"
% render._render_potential_expr(
constraint.where, autogen_context
)
)
args.extend(["%s=%r" % (k, v) for k, v in opts])
return args
if alter:
args = [
repr(render._render_gen_name(autogen_context, constraint.name))
]
if not has_batch:
args += [repr(render._ident(constraint.table.name))]
args.extend(do_expr_where_opts())
return "%(prefix)screate_exclude_constraint(%(args)s)" % {
"prefix": render._alembic_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}
else:
args = do_expr_where_opts()
return "%(prefix)sExcludeConstraint(%(args)s)" % {
"prefix": _postgresql_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}
def _render_potential_column(
value: Union[
ColumnClause[Any], Column[Any], TextClause, FunctionElement[Any]
],
autogen_context: AutogenContext,
) -> str:
if isinstance(value, ColumnClause):
if value.is_literal:
# like literal_column("int8range(from, to)") in ExcludeConstraint
template = "%(prefix)sliteral_column(%(name)r)"
else:
template = "%(prefix)scolumn(%(name)r)"
return template % {
"prefix": render._sqlalchemy_autogenerate_prefix(autogen_context),
"name": value.name,
}
else:
return render._render_potential_expr(
value,
autogen_context,
wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
)

View File

@@ -0,0 +1,225 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
from typing import Any
from typing import Dict
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import JSON
from sqlalchemy import schema
from sqlalchemy import sql
from .base import alter_table
from .base import format_table_name
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.elements import Cast
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.type_api import TypeEngine
from ..operations.batch import BatchOperationsImpl
class SQLiteImpl(DefaultImpl):
__dialect__ = "sqlite"
transactional_ddl = False
"""SQLite supports transactional DDL, but pysqlite does not:
see: http://bugs.python.org/issue10740
"""
def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
Normally, only returns True on SQLite when operations other
than add_column are present.
"""
for op in batch_op.batch:
if op[0] == "add_column":
col = op[1][1]
if isinstance(
col.server_default, schema.DefaultClause
) and isinstance(col.server_default.arg, sql.ClauseElement):
return True
elif (
isinstance(col.server_default, util.sqla_compat.Computed)
and col.server_default.persisted
):
return True
elif op[0] not in ("create_index", "drop_index"):
return True
else:
return False
def add_constraint(self, const: Constraint):
# attempt to distinguish between an
# auto-gen constraint and an explicit one
if const._create_rule is None:
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
elif const._create_rule(self):
util.warn(
"Skipping unsupported ALTER for "
"creation of implicit constraint. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
def drop_constraint(self, const: Constraint):
if const._create_rule is None:
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
def compare_server_default(
self,
inspector_column: Column[Any],
metadata_column: Column[Any],
rendered_metadata_default: Optional[str],
rendered_inspector_default: Optional[str],
) -> bool:
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_metadata_default
)
rendered_metadata_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default
)
if rendered_inspector_default is not None:
rendered_inspector_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_inspector_default
)
rendered_inspector_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default
)
return rendered_inspector_default != rendered_metadata_default
def _guess_if_default_is_unparenthesized_sql_expr(
self, expr: Optional[str]
) -> bool:
"""Determine if a server default is a SQL expression or a constant.
There are too many assertions that expect server defaults to round-trip
identically without parenthesis added so we will add parens only in
very specific cases.
"""
if not expr:
return False
elif re.match(r"^[0-9\.]$", expr):
return False
elif re.match(r"^'.+'$", expr):
return False
elif re.match(r"^\(.+\)$", expr):
return False
else:
return True
def autogen_column_reflect(
self,
inspector: Inspector,
table: Table,
column_info: Dict[str, Any],
) -> None:
# SQLite expression defaults require parenthesis when sent
# as DDL
if self._guess_if_default_is_unparenthesized_sql_expr(
column_info.get("default", None)
):
column_info["default"] = "(%s)" % (column_info["default"],)
def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw
) -> str:
# SQLite expression defaults require parenthesis when sent
# as DDL
str_expr = super().render_ddl_sql_expr(
expr, is_server_default=is_server_default, **kw
)
if (
is_server_default
and self._guess_if_default_is_unparenthesized_sql_expr(str_expr)
):
str_expr = "(%s)" % (str_expr,)
return str_expr
def cast_for_batch_migrate(
self,
existing: Column[Any],
existing_transfer: Dict[str, Union[TypeEngine, Cast]],
new_type: TypeEngine,
) -> None:
if (
existing.type._type_affinity is not new_type._type_affinity
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(
existing_transfer["expr"], new_type
)
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
self._skip_functional_indexes(metadata_indexes, conn_indexes)
@compiles(RenameTable, "sqlite")
def visit_rename_table(
element: RenameTable, compiler: DDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
# @compiles(AddColumn, 'sqlite')
# def visit_add_column(element, compiler, **kw):
# return "%s %s" % (
# alter_table(compiler, element.table_name, element.schema),
# add_column(compiler, element.column, **kw)
# )
# def add_column(compiler, column, **kw):
# text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
# need to modify SQLAlchemy so that the CHECK associated with a Boolean
# or Enum gets placed as part of the column constraints, not the Table
# see ticket 98
# for const in column.constraints:
# text += compiler.process(AddConstraint(const))
# return text

View File

@@ -0,0 +1 @@
from .runtime.environment import * # noqa

View File

@@ -0,0 +1 @@
from .runtime.migration import * # noqa

View File

@@ -0,0 +1,5 @@
from .operations.base import Operations
# create proxy functions for
# each method on the Operations class.
Operations.create_module_class_proxy(globals(), locals())

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,15 @@
from . import toimpl
from .base import AbstractOperations
from .base import BatchOperations
from .base import Operations
from .ops import MigrateOperation
from .ops import MigrationScript
__all__ = [
"AbstractOperations",
"Operations",
"BatchOperations",
"MigrateOperation",
"MigrationScript",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,717 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import Index
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema as sql_schema
from sqlalchemy import Table
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.schema import SchemaEventTarget
from sqlalchemy.util import OrderedDict
from sqlalchemy.util import topological
from ..util import exc
from ..util.sqla_compat import _columns_for_constraint
from ..util.sqla_compat import _copy
from ..util.sqla_compat import _copy_expression
from ..util.sqla_compat import _ensure_scope_for_ddl
from ..util.sqla_compat import _fk_is_self_referential
from ..util.sqla_compat import _idx_table_bound_expressions
from ..util.sqla_compat import _insert_inline
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import _remove_column_from_collection
from ..util.sqla_compat import _resolve_for_variant
from ..util.sqla_compat import _select
from ..util.sqla_compat import constraint_name_defined
from ..util.sqla_compat import constraint_name_string
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.engine import Dialect
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.type_api import TypeEngine
from ..ddl.impl import DefaultImpl
class BatchOperationsImpl:
def __init__(
self,
operations,
table_name,
schema,
recreate,
copy_from,
table_args,
table_kwargs,
reflect_args,
reflect_kwargs,
naming_convention,
partial_reordering,
):
self.operations = operations
self.table_name = table_name
self.schema = schema
if recreate not in ("auto", "always", "never"):
raise ValueError(
"recreate may be one of 'auto', 'always', or 'never'."
)
self.recreate = recreate
self.copy_from = copy_from
self.table_args = table_args
self.table_kwargs = dict(table_kwargs)
self.reflect_args = reflect_args
self.reflect_kwargs = dict(reflect_kwargs)
self.reflect_kwargs.setdefault(
"listeners", list(self.reflect_kwargs.get("listeners", ()))
)
self.reflect_kwargs["listeners"].append(
("column_reflect", operations.impl.autogen_column_reflect)
)
self.naming_convention = naming_convention
self.partial_reordering = partial_reordering
self.batch = []
@property
def dialect(self) -> Dialect:
return self.operations.impl.dialect
@property
def impl(self) -> DefaultImpl:
return self.operations.impl
def _should_recreate(self) -> bool:
if self.recreate == "auto":
return self.operations.impl.requires_recreate_in_batch(self)
elif self.recreate == "always":
return True
else:
return False
def flush(self) -> None:
should_recreate = self._should_recreate()
with _ensure_scope_for_ddl(self.impl.connection):
if not should_recreate:
for opname, arg, kw in self.batch:
fn = getattr(self.operations.impl, opname)
fn(*arg, **kw)
else:
if self.naming_convention:
m1 = MetaData(naming_convention=self.naming_convention)
else:
m1 = MetaData()
if self.copy_from is not None:
existing_table = self.copy_from
reflected = False
else:
if self.operations.migration_context.as_sql:
raise exc.CommandError(
f"This operation cannot proceed in --sql mode; "
f"batch mode with dialect "
f"{self.operations.migration_context.dialect.name} " # noqa: E501
f"requires a live database connection with which "
f'to reflect the table "{self.table_name}". '
f"To generate a batch SQL migration script using "
"table "
'"move and copy", a complete Table object '
f'should be passed to the "copy_from" argument '
"of the batch_alter_table() method so that table "
"reflection can be skipped."
)
existing_table = Table(
self.table_name,
m1,
schema=self.schema,
autoload_with=self.operations.get_bind(),
*self.reflect_args,
**self.reflect_kwargs,
)
reflected = True
batch_impl = ApplyBatchImpl(
self.impl,
existing_table,
self.table_args,
self.table_kwargs,
reflected,
partial_reordering=self.partial_reordering,
)
for opname, arg, kw in self.batch:
fn = getattr(batch_impl, opname)
fn(*arg, **kw)
batch_impl._create(self.impl)
def alter_column(self, *arg, **kw) -> None:
self.batch.append(("alter_column", arg, kw))
def add_column(self, *arg, **kw) -> None:
if (
"insert_before" in kw or "insert_after" in kw
) and not self._should_recreate():
raise exc.CommandError(
"Can't specify insert_before or insert_after when using "
"ALTER; please specify recreate='always'"
)
self.batch.append(("add_column", arg, kw))
def drop_column(self, *arg, **kw) -> None:
self.batch.append(("drop_column", arg, kw))
def add_constraint(self, const: Constraint) -> None:
self.batch.append(("add_constraint", (const,), {}))
def drop_constraint(self, const: Constraint) -> None:
self.batch.append(("drop_constraint", (const,), {}))
def rename_table(self, *arg, **kw):
self.batch.append(("rename_table", arg, kw))
def create_index(self, idx: Index, **kw: Any) -> None:
self.batch.append(("create_index", (idx,), kw))
def drop_index(self, idx: Index, **kw: Any) -> None:
self.batch.append(("drop_index", (idx,), kw))
def create_table_comment(self, table):
self.batch.append(("create_table_comment", (table,), {}))
def drop_table_comment(self, table):
self.batch.append(("drop_table_comment", (table,), {}))
def create_table(self, table):
raise NotImplementedError("Can't create table in batch mode")
def drop_table(self, table):
raise NotImplementedError("Can't drop table in batch mode")
def create_column_comment(self, column):
self.batch.append(("create_column_comment", (column,), {}))
class ApplyBatchImpl:
def __init__(
self,
impl: DefaultImpl,
table: Table,
table_args: tuple,
table_kwargs: Dict[str, Any],
reflected: bool,
partial_reordering: tuple = (),
) -> None:
self.impl = impl
self.table = table # this is a Table object
self.table_args = table_args
self.table_kwargs = table_kwargs
self.temp_table_name = self._calc_temp_name(table.name)
self.new_table: Optional[Table] = None
self.partial_reordering = partial_reordering # tuple of tuples
self.add_col_ordering: Tuple[
Tuple[str, str], ...
] = () # tuple of tuples
self.column_transfers = OrderedDict(
(c.name, {"expr": c}) for c in self.table.c
)
self.existing_ordering = list(self.column_transfers)
self.reflected = reflected
self._grab_table_elements()
@classmethod
def _calc_temp_name(cls, tablename: Union[quoted_name, str]) -> str:
return ("_alembic_tmp_%s" % tablename)[0:50]
def _grab_table_elements(self) -> None:
schema = self.table.schema
self.columns: Dict[str, Column[Any]] = OrderedDict()
for c in self.table.c:
c_copy = _copy(c, schema=schema)
c_copy.unique = c_copy.index = False
# ensure that the type object was copied,
# as we may need to modify it in-place
if isinstance(c.type, SchemaEventTarget):
assert c_copy.type is not c.type
self.columns[c.name] = c_copy
self.named_constraints: Dict[str, Constraint] = {}
self.unnamed_constraints = []
self.col_named_constraints = {}
self.indexes: Dict[str, Index] = {}
self.new_indexes: Dict[str, Index] = {}
for const in self.table.constraints:
if _is_type_bound(const):
continue
elif (
self.reflected
and isinstance(const, CheckConstraint)
and not const.name
):
# TODO: we are skipping unnamed reflected CheckConstraint
# because
# we have no way to determine _is_type_bound() for these.
pass
elif constraint_name_string(const.name):
self.named_constraints[const.name] = const
else:
self.unnamed_constraints.append(const)
if not self.reflected:
for col in self.table.c:
for const in col.constraints:
if const.name:
self.col_named_constraints[const.name] = (col, const)
for idx in self.table.indexes:
self.indexes[idx.name] = idx # type: ignore[index]
for k in self.table.kwargs:
self.table_kwargs.setdefault(k, self.table.kwargs[k])
def _adjust_self_columns_for_partial_reordering(self) -> None:
pairs = set()
col_by_idx = list(self.columns)
if self.partial_reordering:
for tuple_ in self.partial_reordering:
for index, elem in enumerate(tuple_):
if index > 0:
pairs.add((tuple_[index - 1], elem))
else:
for index, elem in enumerate(self.existing_ordering):
if index > 0:
pairs.add((col_by_idx[index - 1], elem))
pairs.update(self.add_col_ordering)
# this can happen if some columns were dropped and not removed
# from existing_ordering. this should be prevented already, but
# conservatively making sure this didn't happen
pairs_list = [p for p in pairs if p[0] != p[1]]
sorted_ = list(
topological.sort(pairs_list, col_by_idx, deterministic_order=True)
)
self.columns = OrderedDict((k, self.columns[k]) for k in sorted_)
self.column_transfers = OrderedDict(
(k, self.column_transfers[k]) for k in sorted_
)
def _transfer_elements_to_new_table(self) -> None:
assert self.new_table is None, "Can only create new table once"
m = MetaData()
schema = self.table.schema
if self.partial_reordering or self.add_col_ordering:
self._adjust_self_columns_for_partial_reordering()
self.new_table = new_table = Table(
self.temp_table_name,
m,
*(list(self.columns.values()) + list(self.table_args)),
schema=schema,
**self.table_kwargs,
)
for const in (
list(self.named_constraints.values()) + self.unnamed_constraints
):
const_columns = {c.key for c in _columns_for_constraint(const)}
if not const_columns.issubset(self.column_transfers):
continue
const_copy: Constraint
if isinstance(const, ForeignKeyConstraint):
if _fk_is_self_referential(const):
# for self-referential constraint, refer to the
# *original* table name, and not _alembic_batch_temp.
# This is consistent with how we're handling
# FK constraints from other tables; we assume SQLite
# no foreign keys just keeps the names unchanged, so
# when we rename back, they match again.
const_copy = _copy(
const, schema=schema, target_table=self.table
)
else:
# "target_table" for ForeignKeyConstraint.copy() is
# only used if the FK is detected as being
# self-referential, which we are handling above.
const_copy = _copy(const, schema=schema)
else:
const_copy = _copy(
const, schema=schema, target_table=new_table
)
if isinstance(const, ForeignKeyConstraint):
self._setup_referent(m, const)
new_table.append_constraint(const_copy)
def _gather_indexes_from_both_tables(self) -> List[Index]:
assert self.new_table is not None
idx: List[Index] = []
for idx_existing in self.indexes.values():
# this is a lift-and-move from Table.to_metadata
if idx_existing._column_flag:
continue
idx_copy = Index(
idx_existing.name,
unique=idx_existing.unique,
*[
_copy_expression(expr, self.new_table)
for expr in _idx_table_bound_expressions(idx_existing)
],
_table=self.new_table,
**idx_existing.kwargs,
)
idx.append(idx_copy)
for index in self.new_indexes.values():
idx.append(
Index(
index.name,
unique=index.unique,
*[self.new_table.c[col] for col in index.columns.keys()],
**index.kwargs,
)
)
return idx
def _setup_referent(
self, metadata: MetaData, constraint: ForeignKeyConstraint
) -> None:
spec = constraint.elements[0]._get_colspec()
parts = spec.split(".")
tname = parts[-2]
if len(parts) == 3:
referent_schema = parts[0]
else:
referent_schema = None
if tname != self.temp_table_name:
key = sql_schema._get_table_key(tname, referent_schema)
def colspec(elem: Any):
return elem._get_colspec()
if key in metadata.tables:
t = metadata.tables[key]
for elem in constraint.elements:
colname = colspec(elem).split(".")[-1]
if colname not in t.c:
t.append_column(Column(colname, sqltypes.NULLTYPE))
else:
Table(
tname,
metadata,
*[
Column(n, sqltypes.NULLTYPE)
for n in [
colspec(elem).split(".")[-1]
for elem in constraint.elements
]
],
schema=referent_schema,
)
def _create(self, op_impl: DefaultImpl) -> None:
self._transfer_elements_to_new_table()
op_impl.prep_table_for_batch(self, self.table)
assert self.new_table is not None
op_impl.create_table(self.new_table)
try:
op_impl._exec(
_insert_inline(self.new_table).from_select(
list(
k
for k, transfer in self.column_transfers.items()
if "expr" in transfer
),
_select(
*[
transfer["expr"]
for transfer in self.column_transfers.values()
if "expr" in transfer
]
),
)
)
op_impl.drop_table(self.table)
except:
op_impl.drop_table(self.new_table)
raise
else:
op_impl.rename_table(
self.temp_table_name, self.table.name, schema=self.table.schema
)
self.new_table.name = self.table.name
try:
for idx in self._gather_indexes_from_both_tables():
op_impl.create_index(idx)
finally:
self.new_table.name = self.temp_table_name
def alter_column(
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Optional[Union[Function[Any], str, bool]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
autoincrement: Optional[Union[bool, Literal["auto"]]] = None,
comment: Union[str, Literal[False]] = False,
**kw,
) -> None:
existing = self.columns[column_name]
existing_transfer: Dict[str, Any] = self.column_transfers[column_name]
if name is not None and name != column_name:
# note that we don't change '.key' - we keep referring
# to the renamed column by its old key in _create(). neat!
existing.name = name
existing_transfer["name"] = name
existing_type = kw.get("existing_type", None)
if existing_type:
resolved_existing_type = _resolve_for_variant(
kw["existing_type"], self.impl.dialect
)
# pop named constraints for Boolean/Enum for rename
if (
isinstance(resolved_existing_type, SchemaEventTarget)
and resolved_existing_type.name # type:ignore[attr-defined] # noqa E501
):
self.named_constraints.pop(
resolved_existing_type.name, # type:ignore[attr-defined] # noqa E501
None,
)
if type_ is not None:
type_ = sqltypes.to_instance(type_)
# old type is being discarded so turn off eventing
# rules. Alternatively we can
# erase the events set up by this type, but this is simpler.
# we also ignore the drop_constraint that will come here from
# Operations.implementation_for(alter_column)
if isinstance(existing.type, SchemaEventTarget):
existing.type._create_events = ( # type:ignore[attr-defined]
existing.type.create_constraint # type:ignore[attr-defined] # noqa
) = False
self.impl.cast_for_batch_migrate(
existing, existing_transfer, type_
)
existing.type = type_
# we *dont* however set events for the new type, because
# alter_column is invoked from
# Operations.implementation_for(alter_column) which already
# will emit an add_constraint()
if nullable is not None:
existing.nullable = nullable
if server_default is not False:
if server_default is None:
existing.server_default = None
else:
sql_schema.DefaultClause(
server_default # type: ignore[arg-type]
)._set_parent(existing)
if autoincrement is not None:
existing.autoincrement = bool(autoincrement)
if comment is not False:
existing.comment = comment
def _setup_dependencies_for_add_column(
self,
colname: str,
insert_before: Optional[str],
insert_after: Optional[str],
) -> None:
index_cols = self.existing_ordering
col_indexes = {name: i for i, name in enumerate(index_cols)}
if not self.partial_reordering:
if insert_after:
if not insert_before:
if insert_after in col_indexes:
# insert after an existing column
idx = col_indexes[insert_after] + 1
if idx < len(index_cols):
insert_before = index_cols[idx]
else:
# insert after a column that is also new
insert_before = dict(self.add_col_ordering)[
insert_after
]
if insert_before:
if not insert_after:
if insert_before in col_indexes:
# insert before an existing column
idx = col_indexes[insert_before] - 1
if idx >= 0:
insert_after = index_cols[idx]
else:
# insert before a column that is also new
insert_after = {
b: a for a, b in self.add_col_ordering
}[insert_before]
if insert_before:
self.add_col_ordering += ((colname, insert_before),)
if insert_after:
self.add_col_ordering += ((insert_after, colname),)
if (
not self.partial_reordering
and not insert_before
and not insert_after
and col_indexes
):
self.add_col_ordering += ((index_cols[-1], colname),)
def add_column(
self,
table_name: str,
column: Column[Any],
insert_before: Optional[str] = None,
insert_after: Optional[str] = None,
**kw,
) -> None:
self._setup_dependencies_for_add_column(
column.name, insert_before, insert_after
)
# we copy the column because operations.add_column()
# gives us a Column that is part of a Table already.
self.columns[column.name] = _copy(column, schema=self.table.schema)
self.column_transfers[column.name] = {}
def drop_column(
self,
table_name: str,
column: Union[ColumnClause[Any], Column[Any]],
**kw,
) -> None:
if column.name in self.table.primary_key.columns:
_remove_column_from_collection(
self.table.primary_key.columns, column
)
del self.columns[column.name]
del self.column_transfers[column.name]
self.existing_ordering.remove(column.name)
# pop named constraints for Boolean/Enum for rename
if (
"existing_type" in kw
and isinstance(kw["existing_type"], SchemaEventTarget)
and kw["existing_type"].name # type:ignore[attr-defined]
):
self.named_constraints.pop(
kw["existing_type"].name, None # type:ignore[attr-defined]
)
def create_column_comment(self, column):
"""the batch table creation function will issue create_column_comment
on the real "impl" as part of the create table process.
That is, the Column object will have the comment on it already,
so when it is received by add_column() it will be a normal part of
the CREATE TABLE and doesn't need an extra step here.
"""
def create_table_comment(self, table):
"""the batch table creation function will issue create_table_comment
on the real "impl" as part of the create table process.
"""
def drop_table_comment(self, table):
"""the batch table creation function will issue drop_table_comment
on the real "impl" as part of the create table process.
"""
def add_constraint(self, const: Constraint) -> None:
if not constraint_name_defined(const.name):
raise ValueError("Constraint must have a name")
if isinstance(const, sql_schema.PrimaryKeyConstraint):
if self.table.primary_key in self.unnamed_constraints:
self.unnamed_constraints.remove(self.table.primary_key)
if constraint_name_string(const.name):
self.named_constraints[const.name] = const
else:
self.unnamed_constraints.append(const)
def drop_constraint(self, const: Constraint) -> None:
if not const.name:
raise ValueError("Constraint must have a name")
try:
if const.name in self.col_named_constraints:
col, const = self.col_named_constraints.pop(const.name)
for col_const in list(self.columns[col.name].constraints):
if col_const.name == const.name:
self.columns[col.name].constraints.remove(col_const)
elif constraint_name_string(const.name):
const = self.named_constraints.pop(const.name)
elif const in self.unnamed_constraints:
self.unnamed_constraints.remove(const)
except KeyError:
if _is_type_bound(const):
# type-bound constraints are only included in the new
# table via their type object in any case, so ignore the
# drop_constraint() that comes here via the
# Operations.implementation_for(alter_column)
return
raise ValueError("No such constraint: '%s'" % const.name)
else:
if isinstance(const, PrimaryKeyConstraint):
for col in const.columns:
self.columns[col.name].primary_key = False
def create_index(self, idx: Index) -> None:
self.new_indexes[idx.name] = idx # type: ignore[index]
def drop_index(self, idx: Index) -> None:
try:
del self.indexes[idx.name] # type: ignore[arg-type]
except KeyError:
raise ValueError("No such index: '%s'" % idx.name)
def rename_table(self, *arg, **kw):
raise NotImplementedError("TODO")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,288 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import schema as sa_schema
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.types import Integer
from sqlalchemy.types import NULLTYPE
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import PrimaryKeyConstraint
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.sql.type_api import TypeEngine
from ..runtime.migration import MigrationContext
class SchemaObjects:
def __init__(
self, migration_context: Optional[MigrationContext] = None
) -> None:
self.migration_context = migration_context
def primary_key_constraint(
self,
name: Optional[sqla_compat._ConstraintNameDefined],
table_name: str,
cols: Sequence[str],
schema: Optional[str] = None,
**dialect_kw,
) -> PrimaryKeyConstraint:
m = self.metadata()
columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
t = sa_schema.Table(table_name, m, *columns, schema=schema)
# SQLAlchemy primary key constraint name arg is wrongly typed on
# the SQLAlchemy side through 2.0.5 at least
p = sa_schema.PrimaryKeyConstraint(
*[t.c[n] for n in cols], name=name, **dialect_kw # type: ignore
)
return p
def foreign_key_constraint(
self,
name: Optional[sqla_compat._ConstraintNameDefined],
source: str,
referent: str,
local_cols: List[str],
remote_cols: List[str],
onupdate: Optional[str] = None,
ondelete: Optional[str] = None,
deferrable: Optional[bool] = None,
source_schema: Optional[str] = None,
referent_schema: Optional[str] = None,
initially: Optional[str] = None,
match: Optional[str] = None,
**dialect_kw,
) -> ForeignKeyConstraint:
m = self.metadata()
if source == referent and source_schema == referent_schema:
t1_cols = local_cols + remote_cols
else:
t1_cols = local_cols
sa_schema.Table(
referent,
m,
*[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
schema=referent_schema,
)
t1 = sa_schema.Table(
source,
m,
*[
sa_schema.Column(n, NULLTYPE)
for n in util.unique_list(t1_cols)
],
schema=source_schema,
)
tname = (
"%s.%s" % (referent_schema, referent)
if referent_schema
else referent
)
dialect_kw["match"] = match
f = sa_schema.ForeignKeyConstraint(
local_cols,
["%s.%s" % (tname, n) for n in remote_cols],
name=name,
onupdate=onupdate,
ondelete=ondelete,
deferrable=deferrable,
initially=initially,
**dialect_kw,
)
t1.append_constraint(f)
return f
def unique_constraint(
self,
name: Optional[sqla_compat._ConstraintNameDefined],
source: str,
local_cols: Sequence[str],
schema: Optional[str] = None,
**kw,
) -> UniqueConstraint:
t = sa_schema.Table(
source,
self.metadata(),
*[sa_schema.Column(n, NULLTYPE) for n in local_cols],
schema=schema,
)
kw["name"] = name
uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
# TODO: need event tests to ensure the event
# is fired off here
t.append_constraint(uq)
return uq
def check_constraint(
self,
name: Optional[sqla_compat._ConstraintNameDefined],
source: str,
condition: Union[str, TextClause, ColumnElement[Any]],
schema: Optional[str] = None,
**kw,
) -> Union[CheckConstraint]:
t = sa_schema.Table(
source,
self.metadata(),
sa_schema.Column("x", Integer),
schema=schema,
)
ck = sa_schema.CheckConstraint(condition, name=name, **kw)
t.append_constraint(ck)
return ck
def generic_constraint(
self,
name: Optional[sqla_compat._ConstraintNameDefined],
table_name: str,
type_: Optional[str],
schema: Optional[str] = None,
**kw,
) -> Any:
t = self.table(table_name, schema=schema)
types: Dict[Optional[str], Any] = {
"foreignkey": lambda name: sa_schema.ForeignKeyConstraint(
[], [], name=name
),
"primary": sa_schema.PrimaryKeyConstraint,
"unique": sa_schema.UniqueConstraint,
"check": lambda name: sa_schema.CheckConstraint("", name=name),
None: sa_schema.Constraint,
}
try:
const = types[type_]
except KeyError as ke:
raise TypeError(
"'type' can be one of %s"
% ", ".join(sorted(repr(x) for x in types))
) from ke
else:
const = const(name=name)
t.append_constraint(const)
return const
def metadata(self) -> MetaData:
kw = {}
if (
self.migration_context is not None
and "target_metadata" in self.migration_context.opts
):
mt = self.migration_context.opts["target_metadata"]
if hasattr(mt, "naming_convention"):
kw["naming_convention"] = mt.naming_convention
return sa_schema.MetaData(**kw)
def table(self, name: str, *columns, **kw) -> Table:
m = self.metadata()
cols = [
sqla_compat._copy(c) if c.table is not None else c
for c in columns
if isinstance(c, Column)
]
# these flags have already added their UniqueConstraint /
# Index objects to the table, so flip them off here.
# SQLAlchemy tometadata() avoids this instead by preserving the
# flags and skipping the constraints that have _type_bound on them,
# but for a migration we'd rather list out the constraints
# explicitly.
_constraints_included = kw.pop("_constraints_included", False)
if _constraints_included:
for c in cols:
c.unique = c.index = False
t = sa_schema.Table(name, m, *cols, **kw)
constraints = [
sqla_compat._copy(elem, target_table=t)
if getattr(elem, "parent", None) is not t
and getattr(elem, "parent", None) is not None
else elem
for elem in columns
if isinstance(elem, (Constraint, Index))
]
for const in constraints:
t.append_constraint(const)
for f in t.foreign_keys:
self._ensure_table_for_fk(m, f)
return t
def column(self, name: str, type_: TypeEngine, **kw) -> Column:
return sa_schema.Column(name, type_, **kw)
def index(
self,
name: Optional[str],
tablename: Optional[str],
columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
schema: Optional[str] = None,
**kw,
) -> Index:
t = sa_schema.Table(
tablename or "no_table",
self.metadata(),
schema=schema,
)
kw["_table"] = t
idx = sa_schema.Index(
name,
*[util.sqla_compat._textual_index_column(t, n) for n in columns],
**kw,
)
return idx
def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]:
if "." in table_key:
tokens = table_key.split(".")
sname: Optional[str] = ".".join(tokens[0:-1])
tname = tokens[-1]
else:
tname = table_key
sname = None
return (sname, tname)
def _ensure_table_for_fk(self, metadata: MetaData, fk: ForeignKey) -> None:
"""create a placeholder Table object for the referent of a
ForeignKey.
"""
if isinstance(fk._colspec, str):
table_key, cname = fk._colspec.rsplit(".", 1)
sname, tname = self._parse_table_key(table_key)
if table_key not in metadata.tables:
rel_t = sa_schema.Table(tname, metadata, schema=sname)
else:
rel_t = metadata.tables[table_key]
if cname not in rel_t.c:
rel_t.append_column(sa_schema.Column(cname, NULLTYPE))

View File

@@ -0,0 +1,226 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from typing import TYPE_CHECKING
from sqlalchemy import schema as sa_schema
from . import ops
from .base import Operations
from ..util.sqla_compat import _copy
from ..util.sqla_compat import sqla_14
if TYPE_CHECKING:
from sqlalchemy.sql.schema import Table
@Operations.implementation_for(ops.AlterColumnOp)
def alter_column(
operations: "Operations", operation: "ops.AlterColumnOp"
) -> None:
compiler = operations.impl.dialect.statement_compiler(
operations.impl.dialect, None
)
existing_type = operation.existing_type
existing_nullable = operation.existing_nullable
existing_server_default = operation.existing_server_default
type_ = operation.modify_type
column_name = operation.column_name
table_name = operation.table_name
schema = operation.schema
server_default = operation.modify_server_default
new_column_name = operation.modify_name
nullable = operation.modify_nullable
comment = operation.modify_comment
existing_comment = operation.existing_comment
def _count_constraint(constraint):
return not isinstance(constraint, sa_schema.PrimaryKeyConstraint) and (
not constraint._create_rule or constraint._create_rule(compiler)
)
if existing_type and type_:
t = operations.schema_obj.table(
table_name,
sa_schema.Column(column_name, existing_type),
schema=schema,
)
for constraint in t.constraints:
if _count_constraint(constraint):
operations.impl.drop_constraint(constraint)
operations.impl.alter_column(
table_name,
column_name,
nullable=nullable,
server_default=server_default,
name=new_column_name,
type_=type_,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
comment=comment,
existing_comment=existing_comment,
**operation.kw,
)
if type_:
t = operations.schema_obj.table(
table_name,
operations.schema_obj.column(column_name, type_),
schema=schema,
)
for constraint in t.constraints:
if _count_constraint(constraint):
operations.impl.add_constraint(constraint)
@Operations.implementation_for(ops.DropTableOp)
def drop_table(operations: "Operations", operation: "ops.DropTableOp") -> None:
operations.impl.drop_table(
operation.to_table(operations.migration_context)
)
@Operations.implementation_for(ops.DropColumnOp)
def drop_column(
operations: "Operations", operation: "ops.DropColumnOp"
) -> None:
column = operation.to_column(operations.migration_context)
operations.impl.drop_column(
operation.table_name, column, schema=operation.schema, **operation.kw
)
@Operations.implementation_for(ops.CreateIndexOp)
def create_index(
operations: "Operations", operation: "ops.CreateIndexOp"
) -> None:
idx = operation.to_index(operations.migration_context)
kw = {}
if operation.if_not_exists is not None:
if not sqla_14:
raise NotImplementedError("SQLAlchemy 1.4+ required")
kw["if_not_exists"] = operation.if_not_exists
operations.impl.create_index(idx, **kw)
@Operations.implementation_for(ops.DropIndexOp)
def drop_index(operations: "Operations", operation: "ops.DropIndexOp") -> None:
kw = {}
if operation.if_exists is not None:
if not sqla_14:
raise NotImplementedError("SQLAlchemy 1.4+ required")
kw["if_exists"] = operation.if_exists
operations.impl.drop_index(
operation.to_index(operations.migration_context),
**kw,
)
@Operations.implementation_for(ops.CreateTableOp)
def create_table(
operations: "Operations", operation: "ops.CreateTableOp"
) -> "Table":
table = operation.to_table(operations.migration_context)
operations.impl.create_table(table)
return table
@Operations.implementation_for(ops.RenameTableOp)
def rename_table(
operations: "Operations", operation: "ops.RenameTableOp"
) -> None:
operations.impl.rename_table(
operation.table_name, operation.new_table_name, schema=operation.schema
)
@Operations.implementation_for(ops.CreateTableCommentOp)
def create_table_comment(
operations: "Operations", operation: "ops.CreateTableCommentOp"
) -> None:
table = operation.to_table(operations.migration_context)
operations.impl.create_table_comment(table)
@Operations.implementation_for(ops.DropTableCommentOp)
def drop_table_comment(
operations: "Operations", operation: "ops.DropTableCommentOp"
) -> None:
table = operation.to_table(operations.migration_context)
operations.impl.drop_table_comment(table)
@Operations.implementation_for(ops.AddColumnOp)
def add_column(operations: "Operations", operation: "ops.AddColumnOp") -> None:
table_name = operation.table_name
column = operation.column
schema = operation.schema
kw = operation.kw
if column.table is not None:
column = _copy(column)
t = operations.schema_obj.table(table_name, column, schema=schema)
operations.impl.add_column(table_name, column, schema=schema, **kw)
for constraint in t.constraints:
if not isinstance(constraint, sa_schema.PrimaryKeyConstraint):
operations.impl.add_constraint(constraint)
for index in t.indexes:
operations.impl.create_index(index)
with_comment = (
operations.impl.dialect.supports_comments
and not operations.impl.dialect.inline_comments
)
comment = column.comment
if comment and with_comment:
operations.impl.create_column_comment(column)
@Operations.implementation_for(ops.AddConstraintOp)
def create_constraint(
operations: "Operations", operation: "ops.AddConstraintOp"
) -> None:
operations.impl.add_constraint(
operation.to_constraint(operations.migration_context)
)
@Operations.implementation_for(ops.DropConstraintOp)
def drop_constraint(
operations: "Operations", operation: "ops.DropConstraintOp"
) -> None:
operations.impl.drop_constraint(
operations.schema_obj.generic_constraint(
operation.constraint_name,
operation.table_name,
operation.constraint_type,
schema=operation.schema,
)
)
@Operations.implementation_for(ops.BulkInsertOp)
def bulk_insert(
operations: "Operations", operation: "ops.BulkInsertOp"
) -> None:
operations.impl.bulk_insert( # type: ignore[union-attr]
operation.table, operation.rows, multiinsert=operation.multiinsert
)
@Operations.implementation_for(ops.ExecuteSQLOp)
def execute_sql(
operations: "Operations", operation: "ops.ExecuteSQLOp"
) -> None:
operations.migration_context.impl.execute(
operation.sqltext, execution_options=operation.execution_options
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,4 @@
from .base import Script
from .base import ScriptDirectory
__all__ = ["ScriptDirectory", "Script"]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,179 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import shlex
import subprocess
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Union
from .. import util
from ..util import compat
REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME"
_registry: dict = {}
def register(name: str) -> Callable:
"""A function decorator that will register that function as a write hook.
See the documentation linked below for an example.
.. seealso::
:ref:`post_write_hooks_custom`
"""
def decorate(fn):
_registry[name] = fn
return fn
return decorate
def _invoke(
name: str, revision: str, options: Mapping[str, Union[str, int]]
) -> Any:
"""Invokes the formatter registered for the given name.
:param name: The name of a formatter in the registry
:param revision: A :class:`.MigrationRevision` instance
:param options: A dict containing kwargs passed to the
specified formatter.
:raises: :class:`alembic.util.CommandError`
"""
try:
hook = _registry[name]
except KeyError as ke:
raise util.CommandError(
f"No formatter with name '{name}' registered"
) from ke
else:
return hook(revision, options)
def _run_hooks(path: str, hook_config: Mapping[str, str]) -> None:
"""Invoke hooks for a generated revision."""
from .base import _split_on_space_comma
names = _split_on_space_comma.split(hook_config.get("hooks", ""))
for name in names:
if not name:
continue
opts = {
key[len(name) + 1 :]: hook_config[key]
for key in hook_config
if key.startswith(name + ".")
}
opts["_hook_name"] = name
try:
type_ = opts["type"]
except KeyError as ke:
raise util.CommandError(
f"Key {name}.type is required for post write hook {name!r}"
) from ke
else:
with util.status(
f"Running post write hook {name!r}", newline=True
):
_invoke(type_, path, opts)
def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]:
"""Parse options from a string into a list.
Also substitutes the revision script token with the actual filename of
the revision script.
If the revision script token doesn't occur in the options string, it is
automatically prepended.
"""
if REVISION_SCRIPT_TOKEN not in cmdline_options_str:
cmdline_options_str = REVISION_SCRIPT_TOKEN + " " + cmdline_options_str
cmdline_options_list = shlex.split(
cmdline_options_str, posix=compat.is_posix
)
cmdline_options_list = [
option.replace(REVISION_SCRIPT_TOKEN, path)
for option in cmdline_options_list
]
return cmdline_options_list
@register("console_scripts")
def console_scripts(
path: str, options: dict, ignore_output: bool = False
) -> None:
try:
entrypoint_name = options["entrypoint"]
except KeyError as ke:
raise util.CommandError(
f"Key {options['_hook_name']}.entrypoint is required for post "
f"write hook {options['_hook_name']!r}"
) from ke
for entry in compat.importlib_metadata_get("console_scripts"):
if entry.name == entrypoint_name:
impl: Any = entry
break
else:
raise util.CommandError(
f"Could not find entrypoint console_scripts.{entrypoint_name}"
)
cwd: Optional[str] = options.get("cwd", None)
cmdline_options_str = options.get("options", "")
cmdline_options_list = _parse_cmdline_options(cmdline_options_str, path)
kw: Dict[str, Any] = {}
if ignore_output:
kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
subprocess.run(
[
sys.executable,
"-c",
f"import {impl.module}; {impl.module}.{impl.attr}()",
]
+ cmdline_options_list,
cwd=cwd,
**kw,
)
@register("exec")
def exec_(path: str, options: dict, ignore_output: bool = False) -> None:
try:
executable = options["executable"]
except KeyError as ke:
raise util.CommandError(
f"Key {options['_hook_name']}.executable is required for post "
f"write hook {options['_hook_name']!r}"
) from ke
cwd: Optional[str] = options.get("cwd", None)
cmdline_options_str = options.get("options", "")
cmdline_options_list = _parse_cmdline_options(cmdline_options_str, path)
kw: Dict[str, Any] = {}
if ignore_output:
kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
subprocess.run(
[
executable,
*cmdline_options_list,
],
cwd=cwd,
**kw,
)

View File

@@ -0,0 +1 @@
Generic single-database configuration with an async dbapi.

View File

@@ -0,0 +1,114 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = ${script_location}
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to ${script_location}/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1,89 @@
import asyncio
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = None
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = async_engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1 @@
Generic single-database configuration.

View File

@@ -0,0 +1,116 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = ${script_location}
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to ${script_location}/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1,78 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = None
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,12 @@
Rudimentary multi-database configuration.
Multi-DB isn't vastly different from generic. The primary difference is that it
will run the migrations N times (depending on how many databases you have
configured), providing one engine name and associated context for each run.
That engine name will then allow the migration to restrict what runs within it to
just the appropriate migrations for that engine. You can see this behavior within
the mako template.
In the provided configuration, you'll need to have `databases` provided in
alembic's config, and an `sqlalchemy.url` provided for each engine name.

View File

@@ -0,0 +1,121 @@
# a multi-database configuration.
[alembic]
# path to migration scripts
script_location = ${script_location}
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to ${script_location}/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
databases = engine1, engine2
[engine1]
sqlalchemy.url = driver://user:pass@localhost/dbname
[engine2]
sqlalchemy.url = driver://user:pass@localhost/dbname2
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1,140 @@
import logging
from logging.config import fileConfig
import re
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
USE_TWOPHASE = False
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
logger = logging.getLogger("alembic.env")
# gather section names referring to different
# databases. These are named "engine1", "engine2"
# in the sample .ini file.
db_names = config.get_main_option("databases", "")
# add your model's MetaData objects here
# for 'autogenerate' support. These must be set
# up to hold just those tables targeting a
# particular database. table.tometadata() may be
# helpful here in case a "copy" of
# a MetaData is needed.
# from myapp import mymodel
# target_metadata = {
# 'engine1':mymodel.metadata1,
# 'engine2':mymodel.metadata2
# }
target_metadata = {}
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
# for the --sql use case, run migrations for each URL into
# individual files.
engines = {}
for name in re.split(r",\s*", db_names):
engines[name] = rec = {}
rec["url"] = context.config.get_section_option(name, "sqlalchemy.url")
for name, rec in engines.items():
logger.info("Migrating database %s" % name)
file_ = "%s.sql" % name
logger.info("Writing output to %s" % file_)
with open(file_, "w") as buffer:
context.configure(
url=rec["url"],
output_buffer=buffer,
target_metadata=target_metadata.get(name),
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations(engine_name=name)
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
# for the direct-to-DB use case, start a transaction on all
# engines, then run all migrations, then commit all transactions.
engines = {}
for name in re.split(r",\s*", db_names):
engines[name] = rec = {}
rec["engine"] = engine_from_config(
context.config.get_section(name, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
for name, rec in engines.items():
engine = rec["engine"]
rec["connection"] = conn = engine.connect()
if USE_TWOPHASE:
rec["transaction"] = conn.begin_twophase()
else:
rec["transaction"] = conn.begin()
try:
for name, rec in engines.items():
logger.info("Migrating database %s" % name)
context.configure(
connection=rec["connection"],
upgrade_token="%s_upgrades" % name,
downgrade_token="%s_downgrades" % name,
target_metadata=target_metadata.get(name),
)
context.run_migrations(engine_name=name)
if USE_TWOPHASE:
for rec in engines.values():
rec["transaction"].prepare()
for rec in engines.values():
rec["transaction"].commit()
except:
for rec in engines.values():
rec["transaction"].rollback()
raise
finally:
for rec in engines.values():
rec["connection"].close()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,47 @@
<%!
import re
%>"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade(engine_name: str) -> None:
globals()["upgrade_%s" % engine_name]()
def downgrade(engine_name: str) -> None:
globals()["downgrade_%s" % engine_name]()
<%
db_names = config.get_main_option("databases")
%>
## generate an "upgrade_<xyz>() / downgrade_<xyz>()" function
## for each database name in the ini file.
% for db_name in re.split(r',\s*', db_names):
def upgrade_${db_name}() -> None:
${context.get("%s_upgrades" % db_name, "pass")}
def downgrade_${db_name}() -> None:
${context.get("%s_downgrades" % db_name, "pass")}
% endfor

View File

@@ -0,0 +1,29 @@
from sqlalchemy.testing import config
from sqlalchemy.testing import emits_warning
from sqlalchemy.testing import engines
from sqlalchemy.testing import exclusions
from sqlalchemy.testing import mock
from sqlalchemy.testing import provide_metadata
from sqlalchemy.testing import skip_if
from sqlalchemy.testing import uses_deprecated
from sqlalchemy.testing.config import combinations
from sqlalchemy.testing.config import fixture
from sqlalchemy.testing.config import requirements as requires
from .assertions import assert_raises
from .assertions import assert_raises_message
from .assertions import emits_python_deprecation_warning
from .assertions import eq_
from .assertions import eq_ignore_whitespace
from .assertions import expect_raises
from .assertions import expect_raises_message
from .assertions import expect_sqlalchemy_deprecated
from .assertions import expect_sqlalchemy_deprecated_20
from .assertions import expect_warnings
from .assertions import is_
from .assertions import is_false
from .assertions import is_not_
from .assertions import is_true
from .assertions import ne_
from .fixtures import TestBase
from .util import resolve_lambda

View File

@@ -0,0 +1,167 @@
from __future__ import annotations
import contextlib
import re
import sys
from typing import Any
from typing import Dict
from sqlalchemy import exc as sa_exc
from sqlalchemy.engine import default
from sqlalchemy.testing.assertions import _expect_warnings
from sqlalchemy.testing.assertions import eq_ # noqa
from sqlalchemy.testing.assertions import is_ # noqa
from sqlalchemy.testing.assertions import is_false # noqa
from sqlalchemy.testing.assertions import is_not_ # noqa
from sqlalchemy.testing.assertions import is_true # noqa
from sqlalchemy.testing.assertions import ne_ # noqa
from sqlalchemy.util import decorator
from ..util import sqla_compat
def _assert_proper_exception_context(exception):
"""assert that any exception we're catching does not have a __context__
without a __cause__, and that __suppress_context__ is never set.
Python 3 will report nested as exceptions as "during the handling of
error X, error Y occurred". That's not what we want to do. we want
these exceptions in a cause chain.
"""
if (
exception.__context__ is not exception.__cause__
and not exception.__suppress_context__
):
assert False, (
"Exception %r was correctly raised but did not set a cause, "
"within context %r as its cause."
% (exception, exception.__context__)
)
def assert_raises(except_cls, callable_, *args, **kw):
return _assert_raises(except_cls, callable_, args, kw, check_context=True)
def assert_raises_context_ok(except_cls, callable_, *args, **kw):
return _assert_raises(except_cls, callable_, args, kw)
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
return _assert_raises(
except_cls, callable_, args, kwargs, msg=msg, check_context=True
)
def assert_raises_message_context_ok(
except_cls, msg, callable_, *args, **kwargs
):
return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
def _assert_raises(
except_cls, callable_, args, kwargs, msg=None, check_context=False
):
with _expect_raises(except_cls, msg, check_context) as ec:
callable_(*args, **kwargs)
return ec.error
class _ErrorContainer:
error: Any = None
@contextlib.contextmanager
def _expect_raises(except_cls, msg=None, check_context=False):
ec = _ErrorContainer()
if check_context:
are_we_already_in_a_traceback = sys.exc_info()[0]
try:
yield ec
success = False
except except_cls as err:
ec.error = err
success = True
if msg is not None:
assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}"
if check_context and not are_we_already_in_a_traceback:
_assert_proper_exception_context(err)
print(str(err).encode("utf-8"))
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
def expect_raises(except_cls, check_context=True):
return _expect_raises(except_cls, check_context=check_context)
def expect_raises_message(except_cls, msg, check_context=True):
return _expect_raises(except_cls, msg=msg, check_context=check_context)
def eq_ignore_whitespace(a, b, msg=None):
a = re.sub(r"^\s+?|\n", "", a)
a = re.sub(r" {2,}", " ", a)
b = re.sub(r"^\s+?|\n", "", b)
b = re.sub(r" {2,}", " ", b)
assert a == b, msg or "%r != %r" % (a, b)
_dialect_mods: Dict[Any, Any] = {}
def _get_dialect(name):
if name is None or name == "default":
return default.DefaultDialect()
else:
d = sqla_compat._create_url(name).get_dialect()()
if name == "postgresql":
d.implicit_returning = True
elif name == "mssql":
d.legacy_schema_aliasing = False
return d
def expect_warnings(*messages, **kw):
"""Context manager which expects one or more warnings.
With no arguments, squelches all SAWarnings emitted via
sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
pass string expressions that will match selected warnings via regex;
all non-matching warnings are sent through.
The expect version **asserts** that the warnings were in fact seen.
Note that the test suite sets SAWarning warnings to raise exceptions.
"""
return _expect_warnings(Warning, messages, **kw)
def emits_python_deprecation_warning(*messages):
"""Decorator form of expect_warnings().
Note that emits_warning does **not** assert that the warnings
were in fact seen.
"""
@decorator
def decorate(fn, *args, **kw):
with _expect_warnings(DeprecationWarning, assert_=False, *messages):
return fn(*args, **kw)
return decorate
def expect_sqlalchemy_deprecated(*messages, **kw):
return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
def expect_sqlalchemy_deprecated_20(*messages, **kw):
return _expect_warnings(sa_exc.RemovedIn20Warning, messages, **kw)

View File

@@ -0,0 +1,518 @@
import importlib.machinery
import os
import shutil
import textwrap
from sqlalchemy.testing import config
from sqlalchemy.testing import provision
from . import util as testing_util
from .. import command
from .. import script
from .. import util
from ..script import Script
from ..script import ScriptDirectory
def _get_staging_directory():
if provision.FOLLOWER_IDENT:
return "scratch_%s" % provision.FOLLOWER_IDENT
else:
return "scratch"
def staging_env(create=True, template="generic", sourceless=False):
cfg = _testing_config()
if create:
path = os.path.join(_get_staging_directory(), "scripts")
assert not os.path.exists(path), (
"staging directory %s already exists; poor cleanup?" % path
)
command.init(cfg, path, template=template)
if sourceless:
try:
# do an import so that a .pyc/.pyo is generated.
util.load_python_file(path, "env.py")
except AttributeError:
# we don't have the migration context set up yet
# so running the .env py throws this exception.
# theoretically we could be using py_compiler here to
# generate .pyc/.pyo without importing but not really
# worth it.
pass
assert sourceless in (
"pep3147_envonly",
"simple",
"pep3147_everything",
), sourceless
make_sourceless(
os.path.join(path, "env.py"),
"pep3147" if "pep3147" in sourceless else "simple",
)
sc = script.ScriptDirectory.from_config(cfg)
return sc
def clear_staging_env():
from sqlalchemy.testing import engines
engines.testing_reaper.close_all()
shutil.rmtree(_get_staging_directory(), True)
def script_file_fixture(txt):
dir_ = os.path.join(_get_staging_directory(), "scripts")
path = os.path.join(dir_, "script.py.mako")
with open(path, "w") as f:
f.write(txt)
def env_file_fixture(txt):
dir_ = os.path.join(_get_staging_directory(), "scripts")
txt = (
"""
from alembic import context
config = context.config
"""
+ txt
)
path = os.path.join(dir_, "env.py")
pyc_path = util.pyc_file_from_path(path)
if pyc_path:
os.unlink(pyc_path)
with open(path, "w") as f:
f.write(txt)
def _sqlite_file_db(tempname="foo.db", future=False, scope=None, **options):
dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/%s" % (dir_, tempname)
if scope and util.sqla_14:
options["scope"] = scope
return testing_util.testing_engine(url=url, future=future, options=options)
def _sqlite_testing_config(sourceless=False, future=False):
dir_ = os.path.join(_get_staging_directory(), "scripts")
url = "sqlite:///%s/foo.db" % dir_
sqlalchemy_future = future or ("future" in config.db.__class__.__module__)
return _write_config_file(
"""
[alembic]
script_location = %s
sqlalchemy.url = %s
sourceless = %s
%s
[loggers]
keys = root,sqlalchemy
[handlers]
keys = console
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = DEBUG
handlers =
qualname = sqlalchemy.engine
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatters]
keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (
dir_,
url,
"true" if sourceless else "false",
"sqlalchemy.future = true" if sqlalchemy_future else "",
)
)
def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
dir_ = os.path.join(_get_staging_directory(), "scripts")
sqlalchemy_future = "future" in config.db.__class__.__module__
url = "sqlite:///%s/foo.db" % dir_
return _write_config_file(
"""
[alembic]
script_location = %s
sqlalchemy.url = %s
sqlalchemy.future = %s
sourceless = %s
version_locations = %%(here)s/model1/ %%(here)s/model2/ %%(here)s/model3/ %s
[loggers]
keys = root
[handlers]
keys = console
[logger_root]
level = WARN
handlers = console
qualname =
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatters]
keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (
dir_,
url,
"true" if sqlalchemy_future else "false",
"true" if sourceless else "false",
extra_version_location,
)
)
def _no_sql_testing_config(dialect="postgresql", directives=""):
"""use a postgresql url with no host so that
connections guaranteed to fail"""
dir_ = os.path.join(_get_staging_directory(), "scripts")
return _write_config_file(
"""
[alembic]
script_location = %s
sqlalchemy.url = %s://
%s
[loggers]
keys = root
[handlers]
keys = console
[logger_root]
level = WARN
handlers = console
qualname =
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatters]
keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (dir_, dialect, directives)
)
def _write_config_file(text):
cfg = _testing_config()
with open(cfg.config_file_name, "w") as f:
f.write(text)
return cfg
def _testing_config():
from alembic.config import Config
if not os.access(_get_staging_directory(), os.F_OK):
os.mkdir(_get_staging_directory())
return Config(os.path.join(_get_staging_directory(), "test_alembic.ini"))
def write_script(
scriptdir, rev_id, content, encoding="ascii", sourceless=False
):
old = scriptdir.revision_map.get_revision(rev_id)
path = old.path
content = textwrap.dedent(content)
if encoding:
content = content.encode(encoding)
with open(path, "wb") as fp:
fp.write(content)
pyc_path = util.pyc_file_from_path(path)
if pyc_path:
os.unlink(pyc_path)
script = Script._from_path(scriptdir, path)
old = scriptdir.revision_map.get_revision(script.revision)
if old.down_revision != script.down_revision:
raise Exception(
"Can't change down_revision " "on a refresh operation."
)
scriptdir.revision_map.add_revision(script, _replace=True)
if sourceless:
make_sourceless(
path, "pep3147" if sourceless == "pep3147_everything" else "simple"
)
def make_sourceless(path, style):
import py_compile
py_compile.compile(path)
if style == "simple":
pyc_path = util.pyc_file_from_path(path)
suffix = importlib.machinery.BYTECODE_SUFFIXES[0]
filepath, ext = os.path.splitext(path)
simple_pyc_path = filepath + suffix
shutil.move(pyc_path, simple_pyc_path)
pyc_path = simple_pyc_path
else:
assert style in ("pep3147", "simple")
pyc_path = util.pyc_file_from_path(path)
assert os.access(pyc_path, os.F_OK)
os.unlink(path)
def three_rev_fixture(cfg):
a = util.rev_id()
b = util.rev_id()
c = util.rev_id()
script = ScriptDirectory.from_config(cfg)
script.generate_revision(a, "revision a", refresh=True, head="base")
write_script(
script,
a,
"""\
"Rev A"
revision = '%s'
down_revision = None
from alembic import op
def upgrade():
op.execute("CREATE STEP 1")
def downgrade():
op.execute("DROP STEP 1")
"""
% a,
)
script.generate_revision(b, "revision b", refresh=True, head=a)
write_script(
script,
b,
f"""# coding: utf-8
"Rev B, méil, %3"
revision = '{b}'
down_revision = '{a}'
from alembic import op
def upgrade():
op.execute("CREATE STEP 2")
def downgrade():
op.execute("DROP STEP 2")
""",
encoding="utf-8",
)
script.generate_revision(c, "revision c", refresh=True, head=b)
write_script(
script,
c,
"""\
"Rev C"
revision = '%s'
down_revision = '%s'
from alembic import op
def upgrade():
op.execute("CREATE STEP 3")
def downgrade():
op.execute("DROP STEP 3")
"""
% (c, b),
)
return a, b, c
def multi_heads_fixture(cfg, a, b, c):
"""Create a multiple head fixture from the three-revs fixture"""
# a->b->c
# -> d -> e
# -> f
d = util.rev_id()
e = util.rev_id()
f = util.rev_id()
script = ScriptDirectory.from_config(cfg)
script.generate_revision(
d, "revision d from b", head=b, splice=True, refresh=True
)
write_script(
script,
d,
"""\
"Rev D"
revision = '%s'
down_revision = '%s'
from alembic import op
def upgrade():
op.execute("CREATE STEP 4")
def downgrade():
op.execute("DROP STEP 4")
"""
% (d, b),
)
script.generate_revision(
e, "revision e from d", head=d, splice=True, refresh=True
)
write_script(
script,
e,
"""\
"Rev E"
revision = '%s'
down_revision = '%s'
from alembic import op
def upgrade():
op.execute("CREATE STEP 5")
def downgrade():
op.execute("DROP STEP 5")
"""
% (e, d),
)
script.generate_revision(
f, "revision f from b", head=b, splice=True, refresh=True
)
write_script(
script,
f,
"""\
"Rev F"
revision = '%s'
down_revision = '%s'
from alembic import op
def upgrade():
op.execute("CREATE STEP 6")
def downgrade():
op.execute("DROP STEP 6")
"""
% (f, b),
)
return d, e, f
def _multidb_testing_config(engines):
"""alembic.ini fixture to work exactly with the 'multidb' template"""
dir_ = os.path.join(_get_staging_directory(), "scripts")
sqlalchemy_future = "future" in config.db.__class__.__module__
databases = ", ".join(engines.keys())
engines = "\n\n".join(
"[%s]\n" "sqlalchemy.url = %s" % (key, value.url)
for key, value in engines.items()
)
return _write_config_file(
"""
[alembic]
script_location = %s
sourceless = false
sqlalchemy.future = %s
databases = %s
%s
[loggers]
keys = root
[handlers]
keys = console
[logger_root]
level = WARN
handlers = console
qualname =
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatters]
keys = generic
[formatter_generic]
format = %%(levelname)-5.5s [%%(name)s] %%(message)s
datefmt = %%H:%%M:%%S
"""
% (dir_, "true" if sqlalchemy_future else "false", databases, engines)
)

View File

@@ -0,0 +1,306 @@
from __future__ import annotations
import configparser
from contextlib import contextmanager
import io
import re
from typing import Any
from typing import Dict
from sqlalchemy import Column
from sqlalchemy import inspect
from sqlalchemy import MetaData
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy.testing import config
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
import alembic
from .assertions import _get_dialect
from ..environment import EnvironmentContext
from ..migration import MigrationContext
from ..operations import Operations
from ..util import sqla_compat
from ..util.sqla_compat import create_mock_engine
from ..util.sqla_compat import sqla_14
from ..util.sqla_compat import sqla_2
testing_config = configparser.ConfigParser()
testing_config.read(["test.cfg"])
class TestBase(SQLAlchemyTestBase):
is_sqlalchemy_future = sqla_2
@testing.fixture()
def ops_context(self, migration_context):
with migration_context.begin_transaction(_per_migration=True):
yield Operations(migration_context)
@testing.fixture
def migration_context(self, connection):
return MigrationContext.configure(
connection, opts=dict(transaction_per_migration=True)
)
@testing.fixture
def connection(self):
with config.db.connect() as conn:
yield conn
class TablesTest(TestBase, SQLAlchemyTablesTest):
pass
if sqla_14:
from sqlalchemy.testing.fixtures import FutureEngineMixin
else:
class FutureEngineMixin: # type:ignore[no-redef]
__requires__ = ("sqlalchemy_14",)
FutureEngineMixin.is_sqlalchemy_future = True
def capture_db(dialect="postgresql://"):
buf = []
def dump(sql, *multiparams, **params):
buf.append(str(sql.compile(dialect=engine.dialect)))
engine = create_mock_engine(dialect, dump)
return engine, buf
_engs: Dict[Any, Any] = {}
@contextmanager
def capture_context_buffer(**kw):
if kw.pop("bytes_io", False):
buf = io.BytesIO()
else:
buf = io.StringIO()
kw.update({"dialect_name": "sqlite", "output_buffer": buf})
conf = EnvironmentContext.configure
def configure(*arg, **opt):
opt.update(**kw)
return conf(*arg, **opt)
with mock.patch.object(EnvironmentContext, "configure", configure):
yield buf
@contextmanager
def capture_engine_context_buffer(**kw):
from .env import _sqlite_file_db
from sqlalchemy import event
buf = io.StringIO()
eng = _sqlite_file_db()
conn = eng.connect()
@event.listens_for(conn, "before_cursor_execute")
def bce(conn, cursor, statement, parameters, context, executemany):
buf.write(statement + "\n")
kw.update({"connection": conn})
conf = EnvironmentContext.configure
def configure(*arg, **opt):
opt.update(**kw)
return conf(*arg, **opt)
with mock.patch.object(EnvironmentContext, "configure", configure):
yield buf
def op_fixture(
dialect="default",
as_sql=False,
naming_convention=None,
literal_binds=False,
native_boolean=None,
):
opts = {}
if naming_convention:
opts["target_metadata"] = MetaData(naming_convention=naming_convention)
class buffer_:
def __init__(self):
self.lines = []
def write(self, msg):
msg = msg.strip()
msg = re.sub(r"[\n\t]", "", msg)
if as_sql:
# the impl produces soft tabs,
# so search for blocks of 4 spaces
msg = re.sub(r" ", "", msg)
msg = re.sub(r"\;\n*$", "", msg)
self.lines.append(msg)
def flush(self):
pass
buf = buffer_()
class ctx(MigrationContext):
def get_buf(self):
return buf
def clear_assertions(self):
buf.lines[:] = []
def assert_(self, *sql):
# TODO: make this more flexible about
# whitespace and such
eq_(buf.lines, [re.sub(r"[\n\t]", "", s) for s in sql])
def assert_contains(self, sql):
for stmt in buf.lines:
if re.sub(r"[\n\t]", "", sql) in stmt:
return
else:
assert False, "Could not locate fragment %r in %r" % (
sql,
buf.lines,
)
if as_sql:
opts["as_sql"] = as_sql
if literal_binds:
opts["literal_binds"] = literal_binds
if not sqla_14 and dialect == "mariadb":
ctx_dialect = _get_dialect("mysql")
ctx_dialect.server_version_info = (10, 4, 0, "MariaDB")
else:
ctx_dialect = _get_dialect(dialect)
if native_boolean is not None:
ctx_dialect.supports_native_boolean = native_boolean
# this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
# which breaks assumptions in the alembic test suite
ctx_dialect.non_native_boolean_check_constraint = True
if not as_sql:
def execute(stmt, *multiparam, **param):
if isinstance(stmt, str):
stmt = text(stmt)
assert stmt.supports_execution
sql = str(stmt.compile(dialect=ctx_dialect))
buf.write(sql)
connection = mock.Mock(dialect=ctx_dialect, execute=execute)
else:
opts["output_buffer"] = buf
connection = None
context = ctx(ctx_dialect, connection, opts)
alembic.op._proxy = Operations(context)
return context
class AlterColRoundTripFixture:
# since these tests are about syntax, use more recent SQLAlchemy as some of
# the type / server default compare logic might not work on older
# SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle
__requires__ = ("alter_column",)
def setUp(self):
self.conn = config.db.connect()
self.ctx = MigrationContext.configure(self.conn)
self.op = Operations(self.ctx)
self.metadata = MetaData()
def _compare_type(self, t1, t2):
c1 = Column("q", t1)
c2 = Column("q", t2)
assert not self.ctx.impl.compare_type(
c1, c2
), "Type objects %r and %r didn't compare as equivalent" % (t1, t2)
def _compare_server_default(self, t1, s1, t2, s2):
c1 = Column("q", t1, server_default=s1)
c2 = Column("q", t2, server_default=s2)
assert not self.ctx.impl.compare_server_default(
c1, c2, s2, s1
), "server defaults %r and %r didn't compare as equivalent" % (s1, s2)
def tearDown(self):
sqla_compat._safe_rollback_connection_transaction(self.conn)
with self.conn.begin():
self.metadata.drop_all(self.conn)
self.conn.close()
def _run_alter_col(self, from_, to_, compare=None):
column = Column(
from_.get("name", "colname"),
from_.get("type", String(10)),
nullable=from_.get("nullable", True),
server_default=from_.get("server_default", None),
# comment=from_.get("comment", None)
)
t = Table("x", self.metadata, column)
with sqla_compat._ensure_scope_for_ddl(self.conn):
t.create(self.conn)
insp = inspect(self.conn)
old_col = insp.get_columns("x")[0]
# TODO: conditional comment support
self.op.alter_column(
"x",
column.name,
existing_type=column.type,
existing_server_default=column.server_default
if column.server_default is not None
else False,
existing_nullable=True if column.nullable else False,
# existing_comment=column.comment,
nullable=to_.get("nullable", None),
# modify_comment=False,
server_default=to_.get("server_default", False),
new_column_name=to_.get("name", None),
type_=to_.get("type", None),
)
insp = inspect(self.conn)
new_col = insp.get_columns("x")[0]
if compare is None:
compare = to_
eq_(
new_col["name"],
compare["name"] if "name" in compare else column.name,
)
self._compare_type(
new_col["type"], compare.get("type", old_col["type"])
)
eq_(new_col["nullable"], compare.get("nullable", column.nullable))
self._compare_server_default(
new_col["type"],
new_col.get("default", None),
compare.get("type", old_col["type"]),
compare["server_default"].text
if "server_default" in compare
else column.server_default.arg.text
if column.server_default is not None
else None,
)

View File

@@ -0,0 +1,4 @@
"""
Bootstrapper for test framework plugins.
"""

View File

@@ -0,0 +1,210 @@
from sqlalchemy.testing.requirements import Requirements
from alembic import util
from alembic.util import sqla_compat
from ..testing import exclusions
class SuiteRequirements(Requirements):
@property
def schemas(self):
"""Target database must support external schemas, and have one
named 'test_schema'."""
return exclusions.open()
@property
def autocommit_isolation(self):
"""target database should support 'AUTOCOMMIT' isolation level"""
return exclusions.closed()
@property
def materialized_views(self):
"""needed for sqlalchemy compat"""
return exclusions.closed()
@property
def unique_constraint_reflection(self):
def doesnt_have_check_uq_constraints(config):
from sqlalchemy import inspect
insp = inspect(config.db)
try:
insp.get_unique_constraints("x")
except NotImplementedError:
return True
except TypeError:
return True
except Exception:
pass
return False
return exclusions.skip_if(doesnt_have_check_uq_constraints)
@property
def sequences(self):
"""Target database must support SEQUENCEs."""
return exclusions.only_if(
[lambda config: config.db.dialect.supports_sequences],
"no sequence support",
)
@property
def foreign_key_match(self):
return exclusions.open()
@property
def foreign_key_constraint_reflection(self):
return exclusions.open()
@property
def check_constraints_w_enforcement(self):
"""Target database must support check constraints
and also enforce them."""
return exclusions.open()
@property
def reflects_pk_names(self):
return exclusions.closed()
@property
def reflects_fk_options(self):
return exclusions.closed()
@property
def sqlalchemy_14(self):
return exclusions.skip_if(
lambda config: not util.sqla_14,
"SQLAlchemy 1.4 or greater required",
)
@property
def sqlalchemy_1x(self):
return exclusions.skip_if(
lambda config: util.sqla_2,
"SQLAlchemy 1.x test",
)
@property
def sqlalchemy_2(self):
return exclusions.skip_if(
lambda config: not util.sqla_2,
"SQLAlchemy 2.x test",
)
@property
def asyncio(self):
def go(config):
try:
import greenlet # noqa: F401
except ImportError:
return False
else:
return True
return self.sqlalchemy_14 + exclusions.only_if(go)
@property
def comments(self):
return exclusions.only_if(
lambda config: config.db.dialect.supports_comments
)
@property
def alter_column(self):
return exclusions.open()
@property
def computed_columns(self):
return exclusions.closed()
@property
def computed_columns_api(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_computed)
)
@property
def computed_reflects_normally(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_computed_reflection)
)
@property
def computed_reflects_as_server_default(self):
return exclusions.closed()
@property
def computed_doesnt_reflect_as_server_default(self):
return exclusions.closed()
@property
def autoincrement_on_composite_pk(self):
return exclusions.closed()
@property
def fk_ondelete_is_reflected(self):
return exclusions.closed()
@property
def fk_onupdate_is_reflected(self):
return exclusions.closed()
@property
def fk_onupdate(self):
return exclusions.open()
@property
def fk_ondelete_restrict(self):
return exclusions.open()
@property
def fk_onupdate_restrict(self):
return exclusions.open()
@property
def fk_ondelete_noaction(self):
return exclusions.open()
@property
def fk_initially(self):
return exclusions.closed()
@property
def fk_deferrable(self):
return exclusions.closed()
@property
def fk_deferrable_is_reflected(self):
return exclusions.closed()
@property
def fk_names(self):
return exclusions.open()
@property
def integer_subtype_comparisons(self):
return exclusions.open()
@property
def no_name_normalize(self):
return exclusions.skip_if(
lambda config: config.db.dialect.requires_name_normalize
)
@property
def identity_columns(self):
return exclusions.closed()
@property
def identity_columns_alter(self):
return exclusions.closed()
@property
def identity_columns_api(self):
return exclusions.only_if(
exclusions.BooleanPredicate(sqla_compat.has_identity)
)

View File

@@ -0,0 +1,169 @@
from itertools import zip_longest
from sqlalchemy import schema
from sqlalchemy.sql.elements import ClauseList
class CompareTable:
def __init__(self, table):
self.table = table
def __eq__(self, other):
if self.table.name != other.name or self.table.schema != other.schema:
return False
for c1, c2 in zip_longest(self.table.c, other.c):
if (c1 is None and c2 is not None) or (
c2 is None and c1 is not None
):
return False
if CompareColumn(c1) != c2:
return False
return True
# TODO: compare constraints, indexes
def __ne__(self, other):
return not self.__eq__(other)
class CompareColumn:
def __init__(self, column):
self.column = column
def __eq__(self, other):
return (
self.column.name == other.name
and self.column.nullable == other.nullable
)
# TODO: datatypes etc
def __ne__(self, other):
return not self.__eq__(other)
class CompareIndex:
def __init__(self, index, name_only=False):
self.index = index
self.name_only = name_only
def __eq__(self, other):
if self.name_only:
return self.index.name == other.name
else:
return (
str(schema.CreateIndex(self.index))
== str(schema.CreateIndex(other))
and self.index.dialect_kwargs == other.dialect_kwargs
)
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
expr = ClauseList(*self.index.expressions)
try:
expr_str = expr.compile().string
except Exception:
expr_str = str(expr)
return f"<CompareIndex {self.index.name}({expr_str})>"
class CompareCheckConstraint:
def __init__(self, constraint):
self.constraint = constraint
def __eq__(self, other):
return (
isinstance(other, schema.CheckConstraint)
and self.constraint.name == other.name
and (str(self.constraint.sqltext) == str(other.sqltext))
and (other.table.name == self.constraint.table.name)
and other.table.schema == self.constraint.table.schema
)
def __ne__(self, other):
return not self.__eq__(other)
class CompareForeignKey:
def __init__(self, constraint):
self.constraint = constraint
def __eq__(self, other):
r1 = (
isinstance(other, schema.ForeignKeyConstraint)
and self.constraint.name == other.name
and (other.table.name == self.constraint.table.name)
and other.table.schema == self.constraint.table.schema
)
if not r1:
return False
for c1, c2 in zip_longest(self.constraint.columns, other.columns):
if (c1 is None and c2 is not None) or (
c2 is None and c1 is not None
):
return False
if CompareColumn(c1) != c2:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
class ComparePrimaryKey:
def __init__(self, constraint):
self.constraint = constraint
def __eq__(self, other):
r1 = (
isinstance(other, schema.PrimaryKeyConstraint)
and self.constraint.name == other.name
and (other.table.name == self.constraint.table.name)
and other.table.schema == self.constraint.table.schema
)
if not r1:
return False
for c1, c2 in zip_longest(self.constraint.columns, other.columns):
if (c1 is None and c2 is not None) or (
c2 is None and c1 is not None
):
return False
if CompareColumn(c1) != c2:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
class CompareUniqueConstraint:
def __init__(self, constraint):
self.constraint = constraint
def __eq__(self, other):
r1 = (
isinstance(other, schema.UniqueConstraint)
and self.constraint.name == other.name
and (other.table.name == self.constraint.table.name)
and other.table.schema == self.constraint.table.schema
)
if not r1:
return False
for c1, c2 in zip_longest(self.constraint.columns, other.columns):
if (c1 is None and c2 is not None) or (
c2 is None and c1 is not None
):
return False
if CompareColumn(c1) != c2:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)

View File

@@ -0,0 +1,7 @@
from .test_autogen_comments import * # noqa
from .test_autogen_computed import * # noqa
from .test_autogen_diffs import * # noqa
from .test_autogen_fks import * # noqa
from .test_autogen_identity import * # noqa
from .test_environment import * # noqa
from .test_op import * # noqa

View File

@@ -0,0 +1,335 @@
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import Set
from sqlalchemy import CHAR
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import event
from sqlalchemy import ForeignKey
from sqlalchemy import Index
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Numeric
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import Text
from sqlalchemy import text
from sqlalchemy import UniqueConstraint
from ... import autogenerate
from ... import util
from ...autogenerate import api
from ...ddl.base import _fk_spec
from ...migration import MigrationContext
from ...operations import ops
from ...testing import config
from ...testing import eq_
from ...testing.env import clear_staging_env
from ...testing.env import staging_env
names_in_this_test: Set[Any] = set()
@event.listens_for(Table, "after_parent_attach")
def new_table(table, parent):
names_in_this_test.add(table.name)
def _default_include_object(obj, name, type_, reflected, compare_to):
if type_ == "table":
return name in names_in_this_test
else:
return True
_default_object_filters: Any = _default_include_object
_default_name_filters: Any = None
class ModelOne:
__requires__ = ("unique_constraint_reflection",)
schema: Any = None
@classmethod
def _get_db_schema(cls):
schema = cls.schema
m = MetaData(schema=schema)
Table(
"user",
m,
Column("id", Integer, primary_key=True),
Column("name", String(50)),
Column("a1", Text),
Column("pw", String(50)),
Index("pw_idx", "pw"),
)
Table(
"address",
m,
Column("id", Integer, primary_key=True),
Column("email_address", String(100), nullable=False),
)
Table(
"order",
m,
Column("order_id", Integer, primary_key=True),
Column(
"amount",
Numeric(8, 2),
nullable=False,
server_default=text("0"),
),
CheckConstraint("amount >= 0", name="ck_order_amount"),
)
Table(
"extra",
m,
Column("x", CHAR),
Column("uid", Integer, ForeignKey("user.id")),
)
return m
@classmethod
def _get_model_schema(cls):
schema = cls.schema
m = MetaData(schema=schema)
Table(
"user",
m,
Column("id", Integer, primary_key=True),
Column("name", String(50), nullable=False),
Column("a1", Text, server_default="x"),
)
Table(
"address",
m,
Column("id", Integer, primary_key=True),
Column("email_address", String(100), nullable=False),
Column("street", String(50)),
UniqueConstraint("email_address", name="uq_email"),
)
Table(
"order",
m,
Column("order_id", Integer, primary_key=True),
Column(
"amount",
Numeric(10, 2),
nullable=True,
server_default=text("0"),
),
Column("user_id", Integer, ForeignKey("user.id")),
CheckConstraint("amount > -1", name="ck_order_amount"),
)
Table(
"item",
m,
Column("id", Integer, primary_key=True),
Column("description", String(100)),
Column("order_id", Integer, ForeignKey("order.order_id")),
CheckConstraint("len(description) > 5"),
)
return m
class _ComparesFKs:
def _assert_fk_diff(
self,
diff,
type_,
source_table,
source_columns,
target_table,
target_columns,
name=None,
conditional_name=None,
source_schema=None,
onupdate=None,
ondelete=None,
initially=None,
deferrable=None,
):
# the public API for ForeignKeyConstraint was not very rich
# in 0.7, 0.8, so here we use the well-known but slightly
# private API to get at its elements
(
fk_source_schema,
fk_source_table,
fk_source_columns,
fk_target_schema,
fk_target_table,
fk_target_columns,
fk_onupdate,
fk_ondelete,
fk_deferrable,
fk_initially,
) = _fk_spec(diff[1])
eq_(diff[0], type_)
eq_(fk_source_table, source_table)
eq_(fk_source_columns, source_columns)
eq_(fk_target_table, target_table)
eq_(fk_source_schema, source_schema)
eq_(fk_onupdate, onupdate)
eq_(fk_ondelete, ondelete)
eq_(fk_initially, initially)
eq_(fk_deferrable, deferrable)
eq_([elem.column.name for elem in diff[1].elements], target_columns)
if conditional_name is not None:
if conditional_name == "servergenerated":
fks = inspect(self.bind).get_foreign_keys(source_table)
server_fk_name = fks[0]["name"]
eq_(diff[1].name, server_fk_name)
else:
eq_(diff[1].name, conditional_name)
else:
eq_(diff[1].name, name)
class AutogenTest(_ComparesFKs):
def _flatten_diffs(self, diffs):
for d in diffs:
if isinstance(d, list):
yield from self._flatten_diffs(d)
else:
yield d
@classmethod
def _get_bind(cls):
return config.db
configure_opts: Dict[Any, Any] = {}
@classmethod
def setup_class(cls):
staging_env()
cls.bind = cls._get_bind()
cls.m1 = cls._get_db_schema()
cls.m1.create_all(cls.bind)
cls.m2 = cls._get_model_schema()
@classmethod
def teardown_class(cls):
cls.m1.drop_all(cls.bind)
clear_staging_env()
def setUp(self):
self.conn = conn = self.bind.connect()
ctx_opts = {
"compare_type": True,
"compare_server_default": True,
"target_metadata": self.m2,
"upgrade_token": "upgrades",
"downgrade_token": "downgrades",
"alembic_module_prefix": "op.",
"sqlalchemy_module_prefix": "sa.",
"include_object": _default_object_filters,
"include_name": _default_name_filters,
}
if self.configure_opts:
ctx_opts.update(self.configure_opts)
self.context = context = MigrationContext.configure(
connection=conn, opts=ctx_opts
)
self.autogen_context = api.AutogenContext(context, self.m2)
def tearDown(self):
self.conn.close()
def _update_context(
self, object_filters=None, name_filters=None, include_schemas=None
):
if include_schemas is not None:
self.autogen_context.opts["include_schemas"] = include_schemas
if object_filters is not None:
self.autogen_context._object_filters = [object_filters]
if name_filters is not None:
self.autogen_context._name_filters = [name_filters]
return self.autogen_context
class AutogenFixtureTest(_ComparesFKs):
def _fixture(
self,
m1,
m2,
include_schemas=False,
opts=None,
object_filters=_default_object_filters,
name_filters=_default_name_filters,
return_ops=False,
max_identifier_length=None,
):
if max_identifier_length:
dialect = self.bind.dialect
existing_length = dialect.max_identifier_length
dialect.max_identifier_length = (
dialect._user_defined_max_identifier_length
) = max_identifier_length
try:
self._alembic_metadata, model_metadata = m1, m2
for m in util.to_list(self._alembic_metadata):
m.create_all(self.bind)
with self.bind.connect() as conn:
ctx_opts = {
"compare_type": True,
"compare_server_default": True,
"target_metadata": model_metadata,
"upgrade_token": "upgrades",
"downgrade_token": "downgrades",
"alembic_module_prefix": "op.",
"sqlalchemy_module_prefix": "sa.",
"include_object": object_filters,
"include_name": name_filters,
"include_schemas": include_schemas,
}
if opts:
ctx_opts.update(opts)
self.context = context = MigrationContext.configure(
connection=conn, opts=ctx_opts
)
autogen_context = api.AutogenContext(context, model_metadata)
uo = ops.UpgradeOps(ops=[])
autogenerate._produce_net_changes(autogen_context, uo)
if return_ops:
return uo
else:
return uo.as_diffs()
finally:
if max_identifier_length:
dialect = self.bind.dialect
dialect.max_identifier_length = (
dialect._user_defined_max_identifier_length
) = existing_length
def setUp(self):
staging_env()
self.bind = config.db
def tearDown(self):
if hasattr(self, "_alembic_metadata"):
for m in util.to_list(self._alembic_metadata):
m.drop_all(self.bind)
clear_staging_env()

View File

@@ -0,0 +1,242 @@
from sqlalchemy import Column
from sqlalchemy import Float
from sqlalchemy import MetaData
from sqlalchemy import String
from sqlalchemy import Table
from ._autogen_fixtures import AutogenFixtureTest
from ...testing import eq_
from ...testing import mock
from ...testing import TestBase
class AutogenerateCommentsTest(AutogenFixtureTest, TestBase):
__backend__ = True
__requires__ = ("comments",)
def test_existing_table_comment_no_change(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
comment="this is some table",
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
comment="this is some table",
)
diffs = self._fixture(m1, m2)
eq_(diffs, [])
def test_add_table_comment(self):
m1 = MetaData()
m2 = MetaData()
Table("some_table", m1, Column("test", String(10), primary_key=True))
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
comment="this is some table",
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_table_comment")
eq_(diffs[0][1].comment, "this is some table")
eq_(diffs[0][2], None)
def test_remove_table_comment(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
comment="this is some table",
)
Table("some_table", m2, Column("test", String(10), primary_key=True))
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "remove_table_comment")
eq_(diffs[0][1].comment, None)
def test_alter_table_comment(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
comment="this is some table",
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
comment="this is also some table",
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_table_comment")
eq_(diffs[0][1].comment, "this is also some table")
eq_(diffs[0][2], "this is some table")
def test_existing_column_comment_no_change(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the amount"),
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the amount"),
)
diffs = self._fixture(m1, m2)
eq_(diffs, [])
def test_add_column_comment(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
Column("amount", Float),
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the amount"),
)
diffs = self._fixture(m1, m2)
eq_(
diffs,
[
[
(
"modify_comment",
None,
"some_table",
"amount",
{
"existing_nullable": True,
"existing_type": mock.ANY,
"existing_server_default": False,
},
None,
"the amount",
)
]
],
)
def test_remove_column_comment(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the amount"),
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
Column("amount", Float),
)
diffs = self._fixture(m1, m2)
eq_(
diffs,
[
[
(
"modify_comment",
None,
"some_table",
"amount",
{
"existing_nullable": True,
"existing_type": mock.ANY,
"existing_server_default": False,
},
"the amount",
None,
)
]
],
)
def test_alter_column_comment(self):
m1 = MetaData()
m2 = MetaData()
Table(
"some_table",
m1,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the amount"),
)
Table(
"some_table",
m2,
Column("test", String(10), primary_key=True),
Column("amount", Float, comment="the adjusted amount"),
)
diffs = self._fixture(m1, m2)
eq_(
diffs,
[
[
(
"modify_comment",
None,
"some_table",
"amount",
{
"existing_nullable": True,
"existing_type": mock.ANY,
"existing_server_default": False,
},
"the amount",
"the adjusted amount",
)
]
],
)

View File

@@ -0,0 +1,203 @@
import sqlalchemy as sa
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from ._autogen_fixtures import AutogenFixtureTest
from ... import testing
from ...testing import config
from ...testing import eq_
from ...testing import exclusions
from ...testing import is_
from ...testing import is_true
from ...testing import mock
from ...testing import TestBase
class AutogenerateComputedTest(AutogenFixtureTest, TestBase):
__requires__ = ("computed_columns",)
__backend__ = True
def test_add_computed_column(self):
m1 = MetaData()
m2 = MetaData()
Table("user", m1, Column("id", Integer, primary_key=True))
Table(
"user",
m2,
Column("id", Integer, primary_key=True),
Column("foo", Integer, sa.Computed("5")),
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_column")
eq_(diffs[0][2], "user")
eq_(diffs[0][3].name, "foo")
c = diffs[0][3].computed
is_true(isinstance(c, sa.Computed))
is_(c.persisted, None)
eq_(str(c.sqltext), "5")
def test_remove_computed_column(self):
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer, primary_key=True),
Column("foo", Integer, sa.Computed("5")),
)
Table("user", m2, Column("id", Integer, primary_key=True))
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "remove_column")
eq_(diffs[0][2], "user")
c = diffs[0][3]
eq_(c.name, "foo")
if config.requirements.computed_reflects_normally.enabled:
is_true(isinstance(c.computed, sa.Computed))
else:
is_(c.computed, None)
if config.requirements.computed_reflects_as_server_default.enabled:
is_true(isinstance(c.server_default, sa.DefaultClause))
eq_(str(c.server_default.arg.text), "5")
elif config.requirements.computed_reflects_normally.enabled:
is_true(isinstance(c.computed, sa.Computed))
else:
is_(c.computed, None)
@testing.combinations(
lambda: (None, sa.Computed("bar*5")),
(lambda: (sa.Computed("bar*5"), None)),
lambda: (
sa.Computed("bar*5"),
sa.Computed("bar * 42", persisted=True),
),
lambda: (sa.Computed("bar*5"), sa.Computed("bar * 42")),
)
@config.requirements.computed_reflects_normally
def test_cant_change_computed_warning(self, test_case):
arg_before, arg_after = testing.resolve_lambda(test_case, **locals())
m1 = MetaData()
m2 = MetaData()
arg_before = [] if arg_before is None else [arg_before]
arg_after = [] if arg_after is None else [arg_after]
Table(
"user",
m1,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, *arg_before),
)
Table(
"user",
m2,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, *arg_after),
)
with mock.patch("alembic.util.warn") as mock_warn:
diffs = self._fixture(m1, m2)
eq_(
mock_warn.mock_calls,
[mock.call("Computed default on user.foo cannot be modified")],
)
eq_(list(diffs), [])
@testing.combinations(
lambda: (None, None),
lambda: (sa.Computed("5"), sa.Computed("5")),
lambda: (sa.Computed("bar*5"), sa.Computed("bar*5")),
(
lambda: (sa.Computed("bar*5"), None),
config.requirements.computed_doesnt_reflect_as_server_default,
),
)
def test_computed_unchanged(self, test_case):
arg_before, arg_after = testing.resolve_lambda(test_case, **locals())
m1 = MetaData()
m2 = MetaData()
arg_before = [] if arg_before is None else [arg_before]
arg_after = [] if arg_after is None else [arg_after]
Table(
"user",
m1,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, *arg_before),
)
Table(
"user",
m2,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, *arg_after),
)
with mock.patch("alembic.util.warn") as mock_warn:
diffs = self._fixture(m1, m2)
eq_(mock_warn.mock_calls, [])
eq_(list(diffs), [])
@config.requirements.computed_reflects_as_server_default
def test_remove_computed_default_on_computed(self):
"""Asserts the current behavior which is that on PG and Oracle,
the GENERATED ALWAYS AS is reflected as a server default which we can't
tell is actually "computed", so these come out as a modification to
the server default.
"""
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer, sa.Computed("bar + 42")),
)
Table(
"user",
m2,
Column("id", Integer, primary_key=True),
Column("bar", Integer),
Column("foo", Integer),
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0][0], "modify_default")
eq_(diffs[0][0][2], "user")
eq_(diffs[0][0][3], "foo")
old = diffs[0][0][-2]
new = diffs[0][0][-1]
is_(new, None)
is_true(isinstance(old, sa.DefaultClause))
if exclusions.against(config, "postgresql"):
eq_(str(old.arg.text), "(bar + 42)")
elif exclusions.against(config, "oracle"):
eq_(str(old.arg.text), '"BAR"+42')

View File

@@ -0,0 +1,273 @@
from sqlalchemy import BigInteger
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy.testing import in_
from ._autogen_fixtures import AutogenFixtureTest
from ... import testing
from ...testing import config
from ...testing import eq_
from ...testing import is_
from ...testing import TestBase
class AlterColumnTest(AutogenFixtureTest, TestBase):
__backend__ = True
@testing.combinations((True,), (False,))
@config.requirements.comments
def test_all_existings_filled(self, pk):
m1 = MetaData()
m2 = MetaData()
Table("a", m1, Column("x", Integer, primary_key=pk))
Table("a", m2, Column("x", Integer, comment="x", primary_key=pk))
alter_col = self._assert_alter_col(m1, m2, pk)
eq_(alter_col.modify_comment, "x")
@testing.combinations((True,), (False,))
@config.requirements.comments
def test_all_existings_filled_in_notnull(self, pk):
m1 = MetaData()
m2 = MetaData()
Table("a", m1, Column("x", Integer, nullable=False, primary_key=pk))
Table(
"a",
m2,
Column("x", Integer, nullable=False, comment="x", primary_key=pk),
)
self._assert_alter_col(m1, m2, pk, nullable=False)
@testing.combinations((True,), (False,))
@config.requirements.comments
def test_all_existings_filled_in_comment(self, pk):
m1 = MetaData()
m2 = MetaData()
Table("a", m1, Column("x", Integer, comment="old", primary_key=pk))
Table("a", m2, Column("x", Integer, comment="new", primary_key=pk))
alter_col = self._assert_alter_col(m1, m2, pk)
eq_(alter_col.existing_comment, "old")
@testing.combinations((True,), (False,))
@config.requirements.comments
def test_all_existings_filled_in_server_default(self, pk):
m1 = MetaData()
m2 = MetaData()
Table(
"a", m1, Column("x", Integer, server_default="5", primary_key=pk)
)
Table(
"a",
m2,
Column(
"x", Integer, server_default="5", comment="new", primary_key=pk
),
)
alter_col = self._assert_alter_col(m1, m2, pk)
in_("5", alter_col.existing_server_default.arg.text)
def _assert_alter_col(self, m1, m2, pk, nullable=None):
ops = self._fixture(m1, m2, return_ops=True)
modify_table = ops.ops[-1]
alter_col = modify_table.ops[0]
if nullable is None:
eq_(alter_col.existing_nullable, not pk)
else:
eq_(alter_col.existing_nullable, nullable)
assert alter_col.existing_type._compare_type_affinity(Integer())
return alter_col
class AutoincrementTest(AutogenFixtureTest, TestBase):
__backend__ = True
__requires__ = ("integer_subtype_comparisons",)
def test_alter_column_autoincrement_none(self):
m1 = MetaData()
m2 = MetaData()
Table("a", m1, Column("x", Integer, nullable=False))
Table("a", m2, Column("x", Integer, nullable=True))
ops = self._fixture(m1, m2, return_ops=True)
assert "autoincrement" not in ops.ops[0].ops[0].kw
def test_alter_column_autoincrement_pk_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("x", Integer, primary_key=True, autoincrement=False),
)
Table(
"a",
m2,
Column("x", BigInteger, primary_key=True, autoincrement=False),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], False)
def test_alter_column_autoincrement_pk_implicit_true(self):
m1 = MetaData()
m2 = MetaData()
Table("a", m1, Column("x", Integer, primary_key=True))
Table("a", m2, Column("x", BigInteger, primary_key=True))
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], True)
def test_alter_column_autoincrement_pk_explicit_true(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a", m1, Column("x", Integer, primary_key=True, autoincrement=True)
)
Table(
"a",
m2,
Column("x", BigInteger, primary_key=True, autoincrement=True),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], True)
def test_alter_column_autoincrement_nonpk_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True),
Column("x", Integer, autoincrement=False),
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True),
Column("x", BigInteger, autoincrement=False),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], False)
def test_alter_column_autoincrement_nonpk_implicit_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True),
Column("x", BigInteger),
)
ops = self._fixture(m1, m2, return_ops=True)
assert "autoincrement" not in ops.ops[0].ops[0].kw
def test_alter_column_autoincrement_nonpk_explicit_true(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", Integer, autoincrement=True),
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", BigInteger, autoincrement=True),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], True)
def test_alter_column_autoincrement_compositepk_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True),
Column("x", Integer, primary_key=True, autoincrement=False),
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True),
Column("x", BigInteger, primary_key=True, autoincrement=False),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], False)
def test_alter_column_autoincrement_compositepk_implicit_false(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True),
Column("x", Integer, primary_key=True),
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True),
Column("x", BigInteger, primary_key=True),
)
ops = self._fixture(m1, m2, return_ops=True)
assert "autoincrement" not in ops.ops[0].ops[0].kw
@config.requirements.autoincrement_on_composite_pk
def test_alter_column_autoincrement_compositepk_explicit_true(self):
m1 = MetaData()
m2 = MetaData()
Table(
"a",
m1,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", Integer, primary_key=True, autoincrement=True),
# on SQLA 1.0 and earlier, this being present
# trips the "add KEY for the primary key" so that the
# AUTO_INCREMENT keyword is accepted by MySQL. SQLA 1.1 and
# greater the columns are just reorganized.
mysql_engine="InnoDB",
)
Table(
"a",
m2,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", BigInteger, primary_key=True, autoincrement=True),
)
ops = self._fixture(m1, m2, return_ops=True)
is_(ops.ops[0].ops[0].kw["autoincrement"], True)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,226 @@
import sqlalchemy as sa
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from alembic.util import sqla_compat
from ._autogen_fixtures import AutogenFixtureTest
from ... import testing
from ...testing import config
from ...testing import eq_
from ...testing import is_true
from ...testing import TestBase
class AutogenerateIdentityTest(AutogenFixtureTest, TestBase):
__requires__ = ("identity_columns",)
__backend__ = True
def test_add_identity_column(self):
m1 = MetaData()
m2 = MetaData()
Table("user", m1, Column("other", sa.Text))
Table(
"user",
m2,
Column("other", sa.Text),
Column(
"id",
Integer,
sa.Identity(start=5, increment=7),
primary_key=True,
),
)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "add_column")
eq_(diffs[0][2], "user")
eq_(diffs[0][3].name, "id")
i = diffs[0][3].identity
is_true(isinstance(i, sa.Identity))
eq_(i.start, 5)
eq_(i.increment, 7)
def test_remove_identity_column(self):
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column(
"id",
Integer,
sa.Identity(start=2, increment=3),
primary_key=True,
),
)
Table("user", m2)
diffs = self._fixture(m1, m2)
eq_(diffs[0][0], "remove_column")
eq_(diffs[0][2], "user")
c = diffs[0][3]
eq_(c.name, "id")
is_true(isinstance(c.identity, sa.Identity))
eq_(c.identity.start, 2)
eq_(c.identity.increment, 3)
def test_no_change_identity_column(self):
m1 = MetaData()
m2 = MetaData()
for m in (m1, m2):
id_ = sa.Identity(start=2)
Table("user", m, Column("id", Integer, id_))
diffs = self._fixture(m1, m2)
eq_(diffs, [])
def test_dialect_kwargs_changes(self):
m1 = MetaData()
m2 = MetaData()
if sqla_compat.identity_has_dialect_kwargs:
args = {"oracle_on_null": True, "oracle_order": True}
else:
args = {"on_null": True, "order": True}
Table("user", m1, Column("id", Integer, sa.Identity(start=2)))
id_ = sa.Identity(start=2, **args)
Table("user", m2, Column("id", Integer, id_))
diffs = self._fixture(m1, m2)
if config.db.name == "oracle":
is_true(len(diffs), 1)
eq_(diffs[0][0][0], "modify_default")
else:
eq_(diffs, [])
@testing.combinations(
(None, dict(start=2)),
(dict(start=2), None),
(dict(start=2), dict(start=2, increment=7)),
(dict(always=False), dict(always=True)),
(
dict(start=1, minvalue=0, maxvalue=100, cycle=True),
dict(start=1, minvalue=0, maxvalue=100, cycle=False),
),
(
dict(start=10, increment=3, maxvalue=9999),
dict(start=10, increment=1, maxvalue=3333),
),
)
@config.requirements.identity_columns_alter
def test_change_identity(self, before, after):
arg_before = (sa.Identity(**before),) if before else ()
arg_after = (sa.Identity(**after),) if after else ()
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer, *arg_before),
Column("other", sa.Text),
)
Table(
"user",
m2,
Column("id", Integer, *arg_after),
Column("other", sa.Text),
)
diffs = self._fixture(m1, m2)
eq_(len(diffs[0]), 1)
diffs = diffs[0][0]
eq_(diffs[0], "modify_default")
eq_(diffs[2], "user")
eq_(diffs[3], "id")
old = diffs[5]
new = diffs[6]
def check(kw, idt):
if kw:
is_true(isinstance(idt, sa.Identity))
for k, v in kw.items():
eq_(getattr(idt, k), v)
else:
is_true(idt in (None, False))
check(before, old)
check(after, new)
def test_add_identity_to_column(self):
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer),
Column("other", sa.Text),
)
Table(
"user",
m2,
Column("id", Integer, sa.Identity(start=2, maxvalue=1000)),
Column("other", sa.Text),
)
diffs = self._fixture(m1, m2)
eq_(len(diffs[0]), 1)
diffs = diffs[0][0]
eq_(diffs[0], "modify_default")
eq_(diffs[2], "user")
eq_(diffs[3], "id")
eq_(diffs[5], None)
added = diffs[6]
is_true(isinstance(added, sa.Identity))
eq_(added.start, 2)
eq_(added.maxvalue, 1000)
def test_remove_identity_from_column(self):
m1 = MetaData()
m2 = MetaData()
Table(
"user",
m1,
Column("id", Integer, sa.Identity(start=2, maxvalue=1000)),
Column("other", sa.Text),
)
Table(
"user",
m2,
Column("id", Integer),
Column("other", sa.Text),
)
diffs = self._fixture(m1, m2)
eq_(len(diffs[0]), 1)
diffs = diffs[0][0]
eq_(diffs[0], "modify_default")
eq_(diffs[2], "user")
eq_(diffs[3], "id")
eq_(diffs[6], None)
removed = diffs[5]
is_true(isinstance(removed, sa.Identity))

View File

@@ -0,0 +1,364 @@
import io
from ...migration import MigrationContext
from ...testing import assert_raises
from ...testing import config
from ...testing import eq_
from ...testing import is_
from ...testing import is_false
from ...testing import is_not_
from ...testing import is_true
from ...testing import ne_
from ...testing.fixtures import TestBase
class MigrationTransactionTest(TestBase):
__backend__ = True
conn = None
def _fixture(self, opts):
self.conn = conn = config.db.connect()
if opts.get("as_sql", False):
self.context = MigrationContext.configure(
dialect=conn.dialect, opts=opts
)
self.context.output_buffer = (
self.context.impl.output_buffer
) = io.StringIO()
else:
self.context = MigrationContext.configure(
connection=conn, opts=opts
)
return self.context
def teardown_method(self):
if self.conn:
self.conn.close()
def test_proxy_transaction_rollback(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
is_false(self.conn.in_transaction())
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
proxy.rollback()
is_false(self.conn.in_transaction())
def test_proxy_transaction_commit(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
proxy.commit()
is_false(self.conn.in_transaction())
def test_proxy_transaction_contextmanager_commit(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
with proxy:
pass
is_false(self.conn.in_transaction())
def test_proxy_transaction_contextmanager_rollback(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
def go():
with proxy:
raise Exception("hi")
assert_raises(Exception, go)
is_false(self.conn.in_transaction())
def test_proxy_transaction_contextmanager_explicit_rollback(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
with proxy:
is_true(self.conn.in_transaction())
proxy.rollback()
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_proxy_transaction_contextmanager_explicit_commit(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
proxy = context.begin_transaction(_per_migration=True)
is_true(self.conn.in_transaction())
with proxy:
is_true(self.conn.in_transaction())
proxy.commit()
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_transaction_per_migration_transactional_ddl(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": True}
)
is_false(self.conn.in_transaction())
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_transaction_per_migration_non_transactional_ddl(self):
context = self._fixture(
{"transaction_per_migration": True, "transactional_ddl": False}
)
is_false(self.conn.in_transaction())
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_transaction_per_all_transactional_ddl(self):
context = self._fixture({"transactional_ddl": True})
is_false(self.conn.in_transaction())
with context.begin_transaction():
is_true(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
is_true(self.conn.in_transaction())
is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_transaction_per_all_non_transactional_ddl(self):
context = self._fixture({"transactional_ddl": False})
is_false(self.conn.in_transaction())
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
def test_transaction_per_all_sqlmode(self):
context = self._fixture({"as_sql": True})
context.execute("step 1")
with context.begin_transaction():
context.execute("step 2")
with context.begin_transaction(_per_migration=True):
context.execute("step 3")
context.execute("step 4")
context.execute("step 5")
if context.impl.transactional_ddl:
self._assert_impl_steps(
"step 1",
"BEGIN",
"step 2",
"step 3",
"step 4",
"COMMIT",
"step 5",
)
else:
self._assert_impl_steps(
"step 1", "step 2", "step 3", "step 4", "step 5"
)
def test_transaction_per_migration_sqlmode(self):
context = self._fixture(
{"as_sql": True, "transaction_per_migration": True}
)
context.execute("step 1")
with context.begin_transaction():
context.execute("step 2")
with context.begin_transaction(_per_migration=True):
context.execute("step 3")
context.execute("step 4")
context.execute("step 5")
if context.impl.transactional_ddl:
self._assert_impl_steps(
"step 1",
"step 2",
"BEGIN",
"step 3",
"COMMIT",
"step 4",
"step 5",
)
else:
self._assert_impl_steps(
"step 1", "step 2", "step 3", "step 4", "step 5"
)
@config.requirements.autocommit_isolation
def test_autocommit_block(self):
context = self._fixture({"transaction_per_migration": True})
is_false(self.conn.in_transaction())
with context.begin_transaction():
is_false(self.conn.in_transaction())
with context.begin_transaction(_per_migration=True):
is_true(self.conn.in_transaction())
with context.autocommit_block():
# in 1.x, self.conn is separate due to the
# execution_options call. however for future they are the
# same connection and there is a "transaction" block
# despite autocommit
if self.is_sqlalchemy_future:
is_(context.connection, self.conn)
else:
is_not_(context.connection, self.conn)
is_false(self.conn.in_transaction())
eq_(
context.connection._execution_options[
"isolation_level"
],
"AUTOCOMMIT",
)
ne_(
context.connection._execution_options.get(
"isolation_level", None
),
"AUTOCOMMIT",
)
is_true(self.conn.in_transaction())
is_false(self.conn.in_transaction())
is_false(self.conn.in_transaction())
@config.requirements.autocommit_isolation
def test_autocommit_block_no_transaction(self):
context = self._fixture({"transaction_per_migration": True})
is_false(self.conn.in_transaction())
with context.autocommit_block():
is_true(context.connection.in_transaction())
# in 1.x, self.conn is separate due to the execution_options
# call. however for future they are the same connection and there
# is a "transaction" block despite autocommit
if self.is_sqlalchemy_future:
is_(context.connection, self.conn)
else:
is_not_(context.connection, self.conn)
is_false(self.conn.in_transaction())
eq_(
context.connection._execution_options["isolation_level"],
"AUTOCOMMIT",
)
ne_(
context.connection._execution_options.get("isolation_level", None),
"AUTOCOMMIT",
)
is_false(self.conn.in_transaction())
def test_autocommit_block_transactional_ddl_sqlmode(self):
context = self._fixture(
{
"transaction_per_migration": True,
"transactional_ddl": True,
"as_sql": True,
}
)
with context.begin_transaction():
context.execute("step 1")
with context.begin_transaction(_per_migration=True):
context.execute("step 2")
with context.autocommit_block():
context.execute("step 3")
context.execute("step 4")
context.execute("step 5")
self._assert_impl_steps(
"step 1",
"BEGIN",
"step 2",
"COMMIT",
"step 3",
"BEGIN",
"step 4",
"COMMIT",
"step 5",
)
def test_autocommit_block_nontransactional_ddl_sqlmode(self):
context = self._fixture(
{
"transaction_per_migration": True,
"transactional_ddl": False,
"as_sql": True,
}
)
with context.begin_transaction():
context.execute("step 1")
with context.begin_transaction(_per_migration=True):
context.execute("step 2")
with context.autocommit_block():
context.execute("step 3")
context.execute("step 4")
context.execute("step 5")
self._assert_impl_steps(
"step 1", "step 2", "step 3", "step 4", "step 5"
)
def _assert_impl_steps(self, *steps):
to_check = self.context.output_buffer.getvalue()
self.context.impl.output_buffer = buf = io.StringIO()
for step in steps:
if step == "BEGIN":
self.context.impl.emit_begin()
elif step == "COMMIT":
self.context.impl.emit_commit()
else:
self.context.impl._exec(step)
eq_(to_check, buf.getvalue())

View File

@@ -0,0 +1,42 @@
"""Test against the builders in the op.* module."""
from sqlalchemy import Column
from sqlalchemy import event
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.sql import text
from ...testing.fixtures import AlterColRoundTripFixture
from ...testing.fixtures import TestBase
@event.listens_for(Table, "after_parent_attach")
def _add_cols(table, metadata):
if table.name == "tbl_with_auto_appended_column":
table.append_column(Column("bat", Integer))
class BackendAlterColumnTest(AlterColRoundTripFixture, TestBase):
__backend__ = True
def test_rename_column(self):
self._run_alter_col({}, {"name": "newname"})
def test_modify_type_int_str(self):
self._run_alter_col({"type": Integer()}, {"type": String(50)})
def test_add_server_default_int(self):
self._run_alter_col({"type": Integer}, {"server_default": text("5")})
def test_modify_server_default_int(self):
self._run_alter_col(
{"type": Integer, "server_default": text("2")},
{"server_default": text("5")},
)
def test_modify_nullable_to_non(self):
self._run_alter_col({}, {"nullable": False})
def test_modify_non_nullable_to_nullable(self):
self._run_alter_col({"nullable": False}, {"nullable": True})

View File

@@ -0,0 +1,126 @@
# testing/util.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import types
from typing import Union
from sqlalchemy.util import inspect_getfullargspec
from ..util import sqla_2
def flag_combinations(*combinations):
"""A facade around @testing.combinations() oriented towards boolean
keyword-based arguments.
Basically generates a nice looking identifier based on the keywords
and also sets up the argument names.
E.g.::
@testing.flag_combinations(
dict(lazy=False, passive=False),
dict(lazy=True, passive=False),
dict(lazy=False, passive=True),
dict(lazy=False, passive=True, raiseload=True),
)
would result in::
@testing.combinations(
('', False, False, False),
('lazy', True, False, False),
('lazy_passive', True, True, False),
('lazy_passive', True, True, True),
id_='iaaa',
argnames='lazy,passive,raiseload'
)
"""
from sqlalchemy.testing import config
keys = set()
for d in combinations:
keys.update(d)
keys = sorted(keys)
return config.combinations(
*[
("_".join(k for k in keys if d.get(k, False)),)
+ tuple(d.get(k, False) for k in keys)
for d in combinations
],
id_="i" + ("a" * len(keys)),
argnames=",".join(keys),
)
def resolve_lambda(__fn, **kw):
"""Given a no-arg lambda and a namespace, return a new lambda that
has all the values filled in.
This is used so that we can have module-level fixtures that
refer to instance-level variables using lambdas.
"""
pos_args = inspect_getfullargspec(__fn)[0]
pass_pos_args = {arg: kw.pop(arg) for arg in pos_args}
glb = dict(__fn.__globals__)
glb.update(kw)
new_fn = types.FunctionType(__fn.__code__, glb)
return new_fn(**pass_pos_args)
def metadata_fixture(ddl="function"):
"""Provide MetaData for a pytest fixture."""
from sqlalchemy.testing import config
from . import fixture_functions
def decorate(fn):
def run_ddl(self):
from sqlalchemy import schema
metadata = self.metadata = schema.MetaData()
try:
result = fn(self, metadata)
metadata.create_all(config.db)
# TODO:
# somehow get a per-function dml erase fixture here
yield result
finally:
metadata.drop_all(config.db)
return fixture_functions.fixture(scope=ddl)(run_ddl)
return decorate
def _safe_int(value: str) -> Union[int, str]:
try:
return int(value)
except:
return value
def testing_engine(url=None, options=None, future=False):
from sqlalchemy.testing import config
from sqlalchemy.testing.engines import testing_engine
if not future:
future = getattr(config._current.options, "future_engine", False)
if not sqla_2:
kw = {"future": future} if future else {}
else:
kw = {}
return testing_engine(url, options, **kw)

View File

@@ -0,0 +1,40 @@
# testing/warnings.py
# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import warnings
from sqlalchemy import exc as sa_exc
from ..util import sqla_14
def setup_filters():
"""Set global warning behavior for the test suite."""
warnings.resetwarnings()
warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
warnings.filterwarnings("error", category=sa_exc.SAWarning)
# some selected deprecations...
warnings.filterwarnings("error", category=DeprecationWarning)
if not sqla_14:
# 1.3 uses pkg_resources in PluginLoader
warnings.filterwarnings(
"ignore",
"pkg_resources is deprecated as an API",
DeprecationWarning,
)
try:
import pytest
except ImportError:
pass
else:
warnings.filterwarnings(
"once", category=pytest.PytestDeprecationWarning
)

View File

@@ -0,0 +1,35 @@
from .editor import open_in_editor as open_in_editor
from .exc import AutogenerateDiffsDetected as AutogenerateDiffsDetected
from .exc import CommandError as CommandError
from .langhelpers import _with_legacy_names as _with_legacy_names
from .langhelpers import asbool as asbool
from .langhelpers import dedupe_tuple as dedupe_tuple
from .langhelpers import Dispatcher as Dispatcher
from .langhelpers import EMPTY_DICT as EMPTY_DICT
from .langhelpers import immutabledict as immutabledict
from .langhelpers import memoized_property as memoized_property
from .langhelpers import ModuleClsProxy as ModuleClsProxy
from .langhelpers import not_none as not_none
from .langhelpers import rev_id as rev_id
from .langhelpers import to_list as to_list
from .langhelpers import to_tuple as to_tuple
from .langhelpers import unique_list as unique_list
from .messaging import err as err
from .messaging import format_as_comma as format_as_comma
from .messaging import msg as msg
from .messaging import obfuscate_url_pw as obfuscate_url_pw
from .messaging import status as status
from .messaging import warn as warn
from .messaging import write_outstream as write_outstream
from .pyfiles import coerce_resource_to_filename as coerce_resource_to_filename
from .pyfiles import load_python_file as load_python_file
from .pyfiles import pyc_file_from_path as pyc_file_from_path
from .pyfiles import template_to_file as template_to_file
from .sqla_compat import has_computed as has_computed
from .sqla_compat import sqla_13 as sqla_13
from .sqla_compat import sqla_14 as sqla_14
from .sqla_compat import sqla_2 as sqla_2
if not sqla_13:
raise CommandError("SQLAlchemy 1.3.0 or greater is required.")

View File

@@ -0,0 +1,89 @@
# mypy: no-warn-unused-ignores
from __future__ import annotations
from configparser import ConfigParser
import io
import os
import sys
import typing
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Union
if True:
# zimports hack for too-long names
from sqlalchemy.util import ( # noqa: F401
inspect_getfullargspec as inspect_getfullargspec,
)
from sqlalchemy.util.compat import ( # noqa: F401
inspect_formatargspec as inspect_formatargspec,
)
is_posix = os.name == "posix"
py311 = sys.version_info >= (3, 11)
py310 = sys.version_info >= (3, 10)
py39 = sys.version_info >= (3, 9)
# produce a wrapper that allows encoded text to stream
# into a given buffer, but doesn't close it.
# not sure of a more idiomatic approach to this.
class EncodedIO(io.TextIOWrapper):
def close(self) -> None:
pass
if py39:
from importlib import resources as _resources
importlib_resources = _resources
from importlib import metadata as _metadata
importlib_metadata = _metadata
from importlib.metadata import EntryPoint as EntryPoint
else:
import importlib_resources # type:ignore # noqa
import importlib_metadata # type:ignore # noqa
from importlib_metadata import EntryPoint # type:ignore # noqa
def importlib_metadata_get(group: str) -> Sequence[EntryPoint]:
ep = importlib_metadata.entry_points()
if hasattr(ep, "select"):
return ep.select(group=group)
else:
return ep.get(group, ()) # type: ignore
def formatannotation_fwdref(
annotation: Any, base_module: Optional[Any] = None
) -> str:
"""vendored from python 3.7"""
# copied over _formatannotation from sqlalchemy 2.0
if isinstance(annotation, str):
return annotation
if getattr(annotation, "__module__", None) == "typing":
return repr(annotation).replace("typing.", "").replace("~", "")
if isinstance(annotation, type):
if annotation.__module__ in ("builtins", base_module):
return repr(annotation.__qualname__)
return annotation.__module__ + "." + annotation.__qualname__
elif isinstance(annotation, typing.TypeVar):
return repr(annotation).replace("~", "")
return repr(annotation).replace("~", "")
def read_config_parser(
file_config: ConfigParser,
file_argument: Sequence[Union[str, os.PathLike[str]]],
) -> List[str]:
if py310:
return file_config.read(file_argument, encoding="locale")
else:
return file_config.read(file_argument)

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
import os
from os.path import exists
from os.path import join
from os.path import splitext
from subprocess import check_call
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from .compat import is_posix
from .exc import CommandError
def open_in_editor(
filename: str, environ: Optional[Dict[str, str]] = None
) -> None:
"""
Opens the given file in a text editor. If the environment variable
``EDITOR`` is set, this is taken as preference.
Otherwise, a list of commonly installed editors is tried.
If no editor matches, an :py:exc:`OSError` is raised.
:param filename: The filename to open. Will be passed verbatim to the
editor command.
:param environ: An optional drop-in replacement for ``os.environ``. Used
mainly for testing.
"""
env = os.environ if environ is None else environ
try:
editor = _find_editor(env)
check_call([editor, filename])
except Exception as exc:
raise CommandError("Error executing editor (%s)" % (exc,)) from exc
def _find_editor(environ: Mapping[str, str]) -> str:
candidates = _default_editors()
for i, var in enumerate(("EDITOR", "VISUAL")):
if var in environ:
user_choice = environ[var]
if exists(user_choice):
return user_choice
if os.sep not in user_choice:
candidates.insert(i, user_choice)
for candidate in candidates:
path = _find_executable(candidate, environ)
if path is not None:
return path
raise OSError(
"No suitable editor found. Please set the "
'"EDITOR" or "VISUAL" environment variables'
)
def _find_executable(
candidate: str, environ: Mapping[str, str]
) -> Optional[str]:
# Assuming this is on the PATH, we need to determine it's absolute
# location. Otherwise, ``check_call`` will fail
if not is_posix and splitext(candidate)[1] != ".exe":
candidate += ".exe"
for path in environ.get("PATH", "").split(os.pathsep):
value = join(path, candidate)
if exists(value):
return value
return None
def _default_editors() -> List[str]:
# Look for an editor. Prefer the user's choice by env-var, fall back to
# most commonly installed editor (nano/vim)
if is_posix:
return ["sensible-editor", "editor", "nano", "vim", "code"]
else:
return ["code.exe", "notepad++.exe", "notepad.exe"]

View File

@@ -0,0 +1,6 @@
class CommandError(Exception):
pass
class AutogenerateDiffsDetected(CommandError):
pass

View File

@@ -0,0 +1,335 @@
from __future__ import annotations
import collections
from collections.abc import Iterable
import textwrap
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import uuid
import warnings
from sqlalchemy.util import asbool as asbool # noqa: F401
from sqlalchemy.util import immutabledict as immutabledict # noqa: F401
from sqlalchemy.util import to_list as to_list # noqa: F401
from sqlalchemy.util import unique_list as unique_list
from .compat import inspect_getfullargspec
if True:
# zimports workaround :(
from sqlalchemy.util import ( # noqa: F401
memoized_property as memoized_property,
)
EMPTY_DICT: Mapping[Any, Any] = immutabledict()
_T = TypeVar("_T", bound=Any)
_C = TypeVar("_C", bound=Callable[..., Any])
class _ModuleClsMeta(type):
def __setattr__(cls, key: str, value: Callable[..., Any]) -> None:
super().__setattr__(key, value)
cls._update_module_proxies(key) # type: ignore
class ModuleClsProxy(metaclass=_ModuleClsMeta):
"""Create module level proxy functions for the
methods on a given class.
The functions will have a compatible signature
as the methods.
"""
_setups: Dict[
Type[Any],
Tuple[
Set[str],
List[Tuple[MutableMapping[str, Any], MutableMapping[str, Any]]],
],
] = collections.defaultdict(lambda: (set(), []))
@classmethod
def _update_module_proxies(cls, name: str) -> None:
attr_names, modules = cls._setups[cls]
for globals_, locals_ in modules:
cls._add_proxied_attribute(name, globals_, locals_, attr_names)
def _install_proxy(self) -> None:
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
globals_["_proxy"] = self
for attr_name in attr_names:
globals_[attr_name] = getattr(self, attr_name)
def _remove_proxy(self) -> None:
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
globals_["_proxy"] = None
for attr_name in attr_names:
del globals_[attr_name]
@classmethod
def create_module_class_proxy(
cls,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
) -> None:
attr_names, modules = cls._setups[cls]
modules.append((globals_, locals_))
cls._setup_proxy(globals_, locals_, attr_names)
@classmethod
def _setup_proxy(
cls,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
attr_names: Set[str],
) -> None:
for methname in dir(cls):
cls._add_proxied_attribute(methname, globals_, locals_, attr_names)
@classmethod
def _add_proxied_attribute(
cls,
methname: str,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
attr_names: Set[str],
) -> None:
if not methname.startswith("_"):
meth = getattr(cls, methname)
if callable(meth):
locals_[methname] = cls._create_method_proxy(
methname, globals_, locals_
)
else:
attr_names.add(methname)
@classmethod
def _create_method_proxy(
cls,
name: str,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
) -> Callable[..., Any]:
fn = getattr(cls, name)
def _name_error(name: str, from_: Exception) -> NoReturn:
raise NameError(
"Can't invoke function '%s', as the proxy object has "
"not yet been "
"established for the Alembic '%s' class. "
"Try placing this code inside a callable."
% (name, cls.__name__)
) from from_
globals_["_name_error"] = _name_error
translations = getattr(fn, "_legacy_translations", [])
if translations:
spec = inspect_getfullargspec(fn)
if spec[0] and spec[0][0] == "self":
spec[0].pop(0)
outer_args = inner_args = "*args, **kw"
translate_str = "args, kw = _translate(%r, %r, %r, args, kw)" % (
fn.__name__,
tuple(spec),
translations,
)
def translate(
fn_name: str, spec: Any, translations: Any, args: Any, kw: Any
) -> Any:
return_kw = {}
return_args = []
for oldname, newname in translations:
if oldname in kw:
warnings.warn(
"Argument %r is now named %r "
"for method %s()." % (oldname, newname, fn_name)
)
return_kw[newname] = kw.pop(oldname)
return_kw.update(kw)
args = list(args)
if spec[3]:
pos_only = spec[0][: -len(spec[3])]
else:
pos_only = spec[0]
for arg in pos_only:
if arg not in return_kw:
try:
return_args.append(args.pop(0))
except IndexError:
raise TypeError(
"missing required positional argument: %s"
% arg
)
return_args.extend(args)
return return_args, return_kw
globals_["_translate"] = translate
else:
outer_args = "*args, **kw"
inner_args = "*args, **kw"
translate_str = ""
func_text = textwrap.dedent(
"""\
def %(name)s(%(args)s):
%(doc)r
%(translate)s
try:
p = _proxy
except NameError as ne:
_name_error('%(name)s', ne)
return _proxy.%(name)s(%(apply_kw)s)
e
"""
% {
"name": name,
"translate": translate_str,
"args": outer_args,
"apply_kw": inner_args,
"doc": fn.__doc__,
}
)
lcl: MutableMapping[str, Any] = {}
exec(func_text, cast("Dict[str, Any]", globals_), lcl)
return cast("Callable[..., Any]", lcl[name])
def _with_legacy_names(translations: Any) -> Any:
def decorate(fn: _C) -> _C:
fn._legacy_translations = translations # type: ignore[attr-defined]
return fn
return decorate
def rev_id() -> str:
return uuid.uuid4().hex[-12:]
@overload
def to_tuple(x: Any, default: Tuple[Any, ...]) -> Tuple[Any, ...]:
...
@overload
def to_tuple(x: None, default: Optional[_T] = ...) -> _T:
...
@overload
def to_tuple(
x: Any, default: Optional[Tuple[Any, ...]] = None
) -> Tuple[Any, ...]:
...
def to_tuple(
x: Any, default: Optional[Tuple[Any, ...]] = None
) -> Optional[Tuple[Any, ...]]:
if x is None:
return default
elif isinstance(x, str):
return (x,)
elif isinstance(x, Iterable):
return tuple(x)
else:
return (x,)
def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:
return tuple(unique_list(tup))
class Dispatcher:
def __init__(self, uselist: bool = False) -> None:
self._registry: Dict[Tuple[Any, ...], Any] = {}
self.uselist = uselist
def dispatch_for(
self, target: Any, qualifier: str = "default"
) -> Callable[[_C], _C]:
def decorate(fn: _C) -> _C:
if self.uselist:
self._registry.setdefault((target, qualifier), []).append(fn)
else:
assert (target, qualifier) not in self._registry
self._registry[(target, qualifier)] = fn
return fn
return decorate
def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
if isinstance(obj, str):
targets: Sequence[Any] = [obj]
elif isinstance(obj, type):
targets = obj.__mro__
else:
targets = type(obj).__mro__
for spcls in targets:
if qualifier != "default" and (spcls, qualifier) in self._registry:
return self._fn_or_list(self._registry[(spcls, qualifier)])
elif (spcls, "default") in self._registry:
return self._fn_or_list(self._registry[(spcls, "default")])
else:
raise ValueError("no dispatch function for object: %s" % obj)
def _fn_or_list(
self, fn_or_list: Union[List[Callable[..., Any]], Callable[..., Any]]
) -> Callable[..., Any]:
if self.uselist:
def go(*arg: Any, **kw: Any) -> None:
if TYPE_CHECKING:
assert isinstance(fn_or_list, Sequence)
for fn in fn_or_list:
fn(*arg, **kw)
return go
else:
return fn_or_list # type: ignore
def branch(self) -> Dispatcher:
"""Return a copy of this dispatcher that is independently
writable."""
d = Dispatcher()
if self.uselist:
d._registry.update(
(k, [fn for fn in self._registry[k]]) for k in self._registry
)
else:
d._registry.update(self._registry)
return d
def not_none(value: Optional[_T]) -> _T:
assert value is not None
return value

View File

@@ -0,0 +1,115 @@
from __future__ import annotations
from collections.abc import Iterable
from contextlib import contextmanager
import logging
import sys
import textwrap
from typing import Iterator
from typing import Optional
from typing import TextIO
from typing import Union
import warnings
from sqlalchemy.engine import url
from . import sqla_compat
log = logging.getLogger(__name__)
# disable "no handler found" errors
logging.getLogger("alembic").addHandler(logging.NullHandler())
try:
import fcntl
import termios
import struct
ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ, struct.pack("HHHH", 0, 0, 0, 0))
_h, TERMWIDTH, _hp, _wp = struct.unpack("HHHH", ioctl)
if TERMWIDTH <= 0: # can occur if running in emacs pseudo-tty
TERMWIDTH = None
except (ImportError, OSError):
TERMWIDTH = None
def write_outstream(
stream: TextIO, *text: Union[str, bytes], quiet: bool = False
) -> None:
if quiet:
return
encoding = getattr(stream, "encoding", "ascii") or "ascii"
for t in text:
if not isinstance(t, bytes):
t = t.encode(encoding, "replace")
t = t.decode(encoding)
try:
stream.write(t)
except OSError:
# suppress "broken pipe" errors.
# no known way to handle this on Python 3 however
# as the exception is "ignored" (noisily) in TextIOWrapper.
break
@contextmanager
def status(
status_msg: str, newline: bool = False, quiet: bool = False
) -> Iterator[None]:
msg(status_msg + " ...", newline, flush=True, quiet=quiet)
try:
yield
except:
if not quiet:
write_outstream(sys.stdout, " FAILED\n")
raise
else:
if not quiet:
write_outstream(sys.stdout, " done\n")
def err(message: str, quiet: bool = False) -> None:
log.error(message)
msg(f"FAILED: {message}", quiet=quiet)
sys.exit(-1)
def obfuscate_url_pw(input_url: str) -> str:
u = url.make_url(input_url)
return sqla_compat.url_render_as_string(u, hide_password=True) # type: ignore # noqa: E501
def warn(msg: str, stacklevel: int = 2) -> None:
warnings.warn(msg, UserWarning, stacklevel=stacklevel)
def msg(
msg: str, newline: bool = True, flush: bool = False, quiet: bool = False
) -> None:
if quiet:
return
if TERMWIDTH is None:
write_outstream(sys.stdout, msg)
if newline:
write_outstream(sys.stdout, "\n")
else:
# left indent output lines
lines = textwrap.wrap(msg, TERMWIDTH)
if len(lines) > 1:
for line in lines[0:-1]:
write_outstream(sys.stdout, " ", line, "\n")
write_outstream(sys.stdout, " ", lines[-1], ("\n" if newline else ""))
if flush:
sys.stdout.flush()
def format_as_comma(value: Optional[Union[str, Iterable[str]]]) -> str:
if value is None:
return ""
elif isinstance(value, str):
return value
elif isinstance(value, Iterable):
return ", ".join(value)
else:
raise ValueError("Don't know how to comma-format %r" % value)

View File

@@ -0,0 +1,114 @@
from __future__ import annotations
import atexit
from contextlib import ExitStack
import importlib
import importlib.machinery
import importlib.util
import os
import re
import tempfile
from types import ModuleType
from typing import Any
from typing import Optional
from mako import exceptions
from mako.template import Template
from . import compat
from .exc import CommandError
def template_to_file(
template_file: str, dest: str, output_encoding: str, **kw: Any
) -> None:
template = Template(filename=template_file)
try:
output = template.render_unicode(**kw).encode(output_encoding)
except:
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as ntf:
ntf.write(
exceptions.text_error_template()
.render_unicode()
.encode(output_encoding)
)
fname = ntf.name
raise CommandError(
"Template rendering failed; see %s for a "
"template-oriented traceback." % fname
)
else:
with open(dest, "wb") as f:
f.write(output)
def coerce_resource_to_filename(fname: str) -> str:
"""Interpret a filename as either a filesystem location or as a package
resource.
Names that are non absolute paths and contain a colon
are interpreted as resources and coerced to a file location.
"""
if not os.path.isabs(fname) and ":" in fname:
tokens = fname.split(":")
# from https://importlib-resources.readthedocs.io/en/latest/migration.html#pkg-resources-resource-filename # noqa E501
file_manager = ExitStack()
atexit.register(file_manager.close)
ref = compat.importlib_resources.files(tokens[0])
for tok in tokens[1:]:
ref = ref / tok
fname = file_manager.enter_context( # type: ignore[assignment]
compat.importlib_resources.as_file(ref)
)
return fname
def pyc_file_from_path(path: str) -> Optional[str]:
"""Given a python source path, locate the .pyc."""
candidate = importlib.util.cache_from_source(path)
if os.path.exists(candidate):
return candidate
# even for pep3147, fall back to the old way of finding .pyc files,
# to support sourceless operation
filepath, ext = os.path.splitext(path)
for ext in importlib.machinery.BYTECODE_SUFFIXES:
if os.path.exists(filepath + ext):
return filepath + ext
else:
return None
def load_python_file(dir_: str, filename: str) -> ModuleType:
"""Load a file from the given path as a Python module."""
module_id = re.sub(r"\W", "_", filename)
path = os.path.join(dir_, filename)
_, ext = os.path.splitext(filename)
if ext == ".py":
if os.path.exists(path):
module = load_module_py(module_id, path)
else:
pyc_path = pyc_file_from_path(path)
if pyc_path is None:
raise ImportError("Can't find Python file %s" % path)
else:
module = load_module_py(module_id, pyc_path)
elif ext in (".pyc", ".pyo"):
module = load_module_py(module_id, path)
else:
assert False
return module
def load_module_py(module_id: str, path: str) -> ModuleType:
spec = importlib.util.spec_from_file_location(module_id, path)
assert spec
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module

View File

@@ -0,0 +1,665 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import contextlib
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import Optional
from typing import Protocol
from typing import Set
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy import __version__
from sqlalchemy import inspect
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import url
from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql import visitors
from sqlalchemy.sql.base import DialectKWArgs
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.visitors import traverse
from typing_extensions import TypeGuard
if TYPE_CHECKING:
from sqlalchemy import ClauseElement
from sqlalchemy import Index
from sqlalchemy import Table
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine import Transaction
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.base import ColumnCollection
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import SchemaItem
from sqlalchemy.sql.selectable import Select
from sqlalchemy.sql.selectable import TableClause
_CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"])
class _CompilerProtocol(Protocol):
def __call__(self, element: Any, compiler: Any, **kw: Any) -> str:
...
def _safe_int(value: str) -> Union[int, str]:
try:
return int(value)
except:
return value
_vers = tuple(
[_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
)
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
# https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d
sqla_14_18 = _vers >= (1, 4, 18)
sqla_14_26 = _vers >= (1, 4, 26)
sqla_2 = _vers >= (2,)
sqlalchemy_version = __version__
try:
from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME # type: ignore[attr-defined] # noqa: E501
except ImportError:
from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME # type: ignore # noqa: E501
class _Unsupported:
"Placeholder for unsupported SQLAlchemy classes"
if TYPE_CHECKING:
def compiles(
element: Type[ClauseElement], *dialects: str
) -> Callable[[_CompilerProtocol], _CompilerProtocol]:
...
else:
from sqlalchemy.ext.compiler import compiles
try:
from sqlalchemy import Computed as Computed
except ImportError:
if not TYPE_CHECKING:
class Computed(_Unsupported):
pass
has_computed = False
has_computed_reflection = False
else:
has_computed = True
has_computed_reflection = _vers >= (1, 3, 16)
try:
from sqlalchemy import Identity as Identity
except ImportError:
if not TYPE_CHECKING:
class Identity(_Unsupported):
pass
has_identity = False
else:
identity_has_dialect_kwargs = issubclass(Identity, DialectKWArgs)
def _get_identity_options_dict(
identity: Union[Identity, schema.Sequence, None],
dialect_kwargs: bool = False,
) -> Dict[str, Any]:
if identity is None:
return {}
elif identity_has_dialect_kwargs:
as_dict = identity._as_dict() # type: ignore
if dialect_kwargs:
assert isinstance(identity, DialectKWArgs)
as_dict.update(identity.dialect_kwargs)
else:
as_dict = {}
if isinstance(identity, Identity):
# always=None means something different than always=False
as_dict["always"] = identity.always
if identity.on_null is not None:
as_dict["on_null"] = identity.on_null
# attributes common to Identity and Sequence
attrs = (
"start",
"increment",
"minvalue",
"maxvalue",
"nominvalue",
"nomaxvalue",
"cycle",
"cache",
"order",
)
as_dict.update(
{
key: getattr(identity, key, None)
for key in attrs
if getattr(identity, key, None) is not None
}
)
return as_dict
has_identity = True
if sqla_2:
from sqlalchemy.sql.base import _NoneName
else:
from sqlalchemy.util import symbol as _NoneName # type: ignore[assignment]
_ConstraintName = Union[None, str, _NoneName]
_ConstraintNameDefined = Union[str, _NoneName]
def constraint_name_defined(
name: _ConstraintName,
) -> TypeGuard[_ConstraintNameDefined]:
return name is _NONE_NAME or isinstance(name, (str, _NoneName))
def constraint_name_string(
name: _ConstraintName,
) -> TypeGuard[str]:
return isinstance(name, str)
def constraint_name_or_none(
name: _ConstraintName,
) -> Optional[str]:
return name if constraint_name_string(name) else None
AUTOINCREMENT_DEFAULT = "auto"
@contextlib.contextmanager
def _ensure_scope_for_ddl(
connection: Optional[Connection],
) -> Iterator[None]:
try:
in_transaction = connection.in_transaction # type: ignore[union-attr]
except AttributeError:
# catch for MockConnection, None
in_transaction = None
pass
# yield outside the catch
if in_transaction is None:
yield
else:
if not in_transaction():
assert connection is not None
with connection.begin():
yield
else:
yield
def url_render_as_string(url, hide_password=True):
if sqla_14:
return url.render_as_string(hide_password=hide_password)
else:
return url.__to_string__(hide_password=hide_password)
def _safe_begin_connection_transaction(
connection: Connection,
) -> Transaction:
transaction = _get_connection_transaction(connection)
if transaction:
return transaction
else:
return connection.begin()
def _safe_commit_connection_transaction(
connection: Connection,
) -> None:
transaction = _get_connection_transaction(connection)
if transaction:
transaction.commit()
def _safe_rollback_connection_transaction(
connection: Connection,
) -> None:
transaction = _get_connection_transaction(connection)
if transaction:
transaction.rollback()
def _get_connection_in_transaction(connection: Optional[Connection]) -> bool:
try:
in_transaction = connection.in_transaction # type: ignore
except AttributeError:
# catch for MockConnection
return False
else:
return in_transaction()
def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
return idx.expressions # type: ignore
def _copy(schema_item: _CE, **kw) -> _CE:
if hasattr(schema_item, "_copy"):
return schema_item._copy(**kw)
else:
return schema_item.copy(**kw) # type: ignore[union-attr]
def _get_connection_transaction(
connection: Connection,
) -> Optional[Transaction]:
if sqla_14:
return connection.get_transaction()
else:
r = connection._root # type: ignore[attr-defined]
return r._Connection__transaction
def _create_url(*arg, **kw) -> url.URL:
if hasattr(url.URL, "create"):
return url.URL.create(*arg, **kw)
else:
return url.URL(*arg, **kw)
def _connectable_has_table(
connectable: Connection, tablename: str, schemaname: Union[str, None]
) -> bool:
if sqla_14:
return inspect(connectable).has_table(tablename, schemaname)
else:
return connectable.dialect.has_table(
connectable, tablename, schemaname
)
def _exec_on_inspector(inspector, statement, **params):
if sqla_14:
with inspector._operation_context() as conn:
return conn.execute(statement, params)
else:
return inspector.bind.execute(statement, params)
def _nullability_might_be_unset(metadata_column):
if not sqla_14:
return metadata_column.nullable
else:
from sqlalchemy.sql import schema
return (
metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
)
def _server_default_is_computed(*server_default) -> bool:
if not has_computed:
return False
else:
return any(isinstance(sd, Computed) for sd in server_default)
def _server_default_is_identity(*server_default) -> bool:
if not sqla_14:
return False
else:
return any(isinstance(sd, Identity) for sd in server_default)
def _table_for_constraint(constraint: Constraint) -> Table:
if isinstance(constraint, ForeignKeyConstraint):
table = constraint.parent
assert table is not None
return table # type: ignore[return-value]
else:
return constraint.table
def _columns_for_constraint(constraint):
if isinstance(constraint, ForeignKeyConstraint):
return [fk.parent for fk in constraint.elements]
elif isinstance(constraint, CheckConstraint):
return _find_columns(constraint.sqltext)
else:
return list(constraint.columns)
def _reflect_table(inspector: Inspector, table: Table) -> None:
if sqla_14:
return inspector.reflect_table(table, None)
else:
return inspector.reflecttable( # type: ignore[attr-defined]
table, None
)
def _resolve_for_variant(type_, dialect):
if _type_has_variants(type_):
base_type, mapping = _get_variant_mapping(type_)
return mapping.get(dialect.name, base_type)
else:
return type_
if hasattr(sqltypes.TypeEngine, "_variant_mapping"):
def _type_has_variants(type_):
return bool(type_._variant_mapping)
def _get_variant_mapping(type_):
return type_, type_._variant_mapping
else:
def _type_has_variants(type_):
return type(type_) is sqltypes.Variant
def _get_variant_mapping(type_):
return type_.impl, type_.mapping
def _fk_spec(constraint: ForeignKeyConstraint) -> Any:
if TYPE_CHECKING:
assert constraint.columns is not None
assert constraint.elements is not None
assert isinstance(constraint.parent, Table)
source_columns = [
constraint.columns[key].name for key in constraint.column_keys
]
source_table = constraint.parent.name
source_schema = constraint.parent.schema
target_schema = constraint.elements[0].column.table.schema
target_table = constraint.elements[0].column.table.name
target_columns = [element.column.name for element in constraint.elements]
ondelete = constraint.ondelete
onupdate = constraint.onupdate
deferrable = constraint.deferrable
initially = constraint.initially
return (
source_schema,
source_table,
source_columns,
target_schema,
target_table,
target_columns,
onupdate,
ondelete,
deferrable,
initially,
)
def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
spec = constraint.elements[0]._get_colspec()
tokens = spec.split(".")
tokens.pop(-1) # colname
tablekey = ".".join(tokens)
assert constraint.parent is not None
return tablekey == constraint.parent.key
def _is_type_bound(constraint: Constraint) -> bool:
# this deals with SQLAlchemy #3260, don't copy CHECK constraints
# that will be generated by the type.
# new feature added for #3260
return constraint._type_bound
def _find_columns(clause):
"""locate Column objects within the given expression."""
cols: Set[ColumnElement[Any]] = set()
traverse(clause, {}, {"column": cols.add})
return cols
def _remove_column_from_collection(
collection: ColumnCollection, column: Union[Column[Any], ColumnClause[Any]]
) -> None:
"""remove a column from a ColumnCollection."""
# workaround for older SQLAlchemy, remove the
# same object that's present
assert column.key is not None
to_remove = collection[column.key]
# SQLAlchemy 2.0 will use more ReadOnlyColumnCollection
# (renamed from ImmutableColumnCollection)
if hasattr(collection, "_immutable") or hasattr(collection, "_readonly"):
collection._parent.remove(to_remove)
else:
collection.remove(to_remove)
def _textual_index_column(
table: Table, text_: Union[str, TextClause, ColumnElement[Any]]
) -> Union[ColumnElement[Any], Column[Any]]:
"""a workaround for the Index construct's severe lack of flexibility"""
if isinstance(text_, str):
c = Column(text_, sqltypes.NULLTYPE)
table.append_column(c)
return c
elif isinstance(text_, TextClause):
return _textual_index_element(table, text_)
elif isinstance(text_, _textual_index_element):
return _textual_index_column(table, text_.text)
elif isinstance(text_, sql.ColumnElement):
return _copy_expression(text_, table)
else:
raise ValueError("String or text() construct expected")
def _copy_expression(expression: _CE, target_table: Table) -> _CE:
def replace(col):
if (
isinstance(col, Column)
and col.table is not None
and col.table is not target_table
):
if col.name in target_table.c:
return target_table.c[col.name]
else:
c = _copy(col)
target_table.append_column(c)
return c
else:
return None
return visitors.replacement_traverse( # type: ignore[call-overload]
expression, {}, replace
)
class _textual_index_element(sql.ColumnElement):
"""Wrap around a sqlalchemy text() construct in such a way that
we appear like a column-oriented SQL expression to an Index
construct.
The issue here is that currently the Postgresql dialect, the biggest
recipient of functional indexes, keys all the index expressions to
the corresponding column expressions when rendering CREATE INDEX,
so the Index we create here needs to have a .columns collection that
is the same length as the .expressions collection. Ultimately
SQLAlchemy should support text() expressions in indexes.
See SQLAlchemy issue 3174.
"""
__visit_name__ = "_textual_idx_element"
def __init__(self, table: Table, text: TextClause) -> None:
self.table = table
self.text = text
self.key = text.text
self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
table.append_column(self.fake_column)
def get_children(self):
return [self.fake_column]
@compiles(_textual_index_element)
def _render_textual_index_column(
element: _textual_index_element, compiler: SQLCompiler, **kw
) -> str:
return compiler.process(element.text, **kw)
class _literal_bindparam(BindParameter):
pass
@compiles(_literal_bindparam)
def _render_literal_bindparam(
element: _literal_bindparam, compiler: SQLCompiler, **kw
) -> str:
return compiler.render_literal_bindparam(element, **kw)
def _column_kwargs(col: Column) -> Mapping:
if sqla_13:
return col.kwargs
else:
return {}
def _get_constraint_final_name(
constraint: Union[Index, Constraint], dialect: Optional[Dialect]
) -> Optional[str]:
if constraint.name is None:
return None
assert dialect is not None
if sqla_14:
# for SQLAlchemy 1.4 we would like to have the option to expand
# the use of "deferred" names for constraints as well as to have
# some flexibility with "None" name and similar; make use of new
# SQLAlchemy API to return what would be the final compiled form of
# the name for this dialect.
return dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
else:
# prior to SQLAlchemy 1.4, work around quoting logic to get at the
# final compiled name without quotes.
if hasattr(constraint.name, "quote"):
# might be quoted_name, might be truncated_name, keep it the
# same
quoted_name_cls: type = type(constraint.name)
else:
quoted_name_cls = quoted_name
new_name = quoted_name_cls(str(constraint.name), quote=False)
constraint = constraint.__class__(name=new_name)
if isinstance(constraint, schema.Index):
# name should not be quoted.
d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type]
return d._prepared_index_name(constraint)
else:
# name should not be quoted.
return dialect.identifier_preparer.format_constraint(constraint)
def _constraint_is_named(
constraint: Union[Constraint, Index], dialect: Optional[Dialect]
) -> bool:
if sqla_14:
if constraint.name is None:
return False
assert dialect is not None
name = dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
return name is not None
else:
return constraint.name is not None
def _is_mariadb(mysql_dialect: Dialect) -> bool:
if sqla_14:
return mysql_dialect.is_mariadb # type: ignore[attr-defined]
else:
return bool(
mysql_dialect.server_version_info
and mysql_dialect._is_mariadb # type: ignore[attr-defined]
)
def _mariadb_normalized_version_info(mysql_dialect):
return mysql_dialect._mariadb_normalized_version_info
def _insert_inline(table: Union[TableClause, Table]) -> Insert:
if sqla_14:
return table.insert().inline()
else:
return table.insert(inline=True) # type: ignore[call-arg]
if sqla_14:
from sqlalchemy import create_mock_engine
# weird mypy workaround
from sqlalchemy import select as _sa_select
_select = _sa_select
else:
from sqlalchemy import create_engine
def create_mock_engine(url, executor, **kw): # type: ignore[misc]
return create_engine(
"postgresql://", strategy="mock", executor=executor
)
def _select(*columns, **kw) -> Select:
return sql.select(list(columns), **kw) # type: ignore[call-overload]
def is_expression_index(index: Index) -> bool:
for expr in index.expressions:
if is_expression(expr):
return True
return False
def is_expression(expr: Any) -> bool:
while isinstance(expr, UnaryExpression):
expr = expr.element
if not isinstance(expr, ColumnClause) or expr.is_literal:
return True
return False