1
0
mirror of https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git synced 2025-08-14 00:25:46 +02:00

Все подряд

This commit is contained in:
MoonTestUse1
2024-12-31 02:37:57 +06:00
parent 8e53bb6cb2
commit d5780b2eab
3258 changed files with 1087440 additions and 268 deletions

View File

@@ -0,0 +1,6 @@
from . import mssql
from . import mysql
from . import oracle
from . import postgresql
from . import sqlite
from .impl import DefaultImpl as DefaultImpl

View File

@@ -0,0 +1,325 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import UniqueConstraint
from typing_extensions import TypeGuard
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from alembic.autogenerate.api import AutogenContext
from alembic.ddl.impl import DefaultImpl
CompareConstraintType = Union[Constraint, Index]
_C = TypeVar("_C", bound=CompareConstraintType)
_clsreg: Dict[str, Type[_constraint_sig]] = {}
class ComparisonResult(NamedTuple):
status: Literal["equal", "different", "skip"]
message: str
@property
def is_equal(self) -> bool:
return self.status == "equal"
@property
def is_different(self) -> bool:
return self.status == "different"
@property
def is_skip(self) -> bool:
return self.status == "skip"
@classmethod
def Equal(cls) -> ComparisonResult:
"""the constraints are equal."""
return cls("equal", "The two constraints are equal")
@classmethod
def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
"""the constraints are different for the provided reason(s)."""
return cls("different", ", ".join(util.to_list(reason)))
@classmethod
def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
"""the constraint cannot be compared for the provided reason(s).
The message is logged, but the constraints will be otherwise
considered equal, meaning that no migration command will be
generated.
"""
return cls("skip", ", ".join(util.to_list(reason)))
class _constraint_sig(Generic[_C]):
const: _C
_sig: Tuple[Any, ...]
name: Optional[sqla_compat._ConstraintNameDefined]
impl: DefaultImpl
_is_index: ClassVar[bool] = False
_is_fk: ClassVar[bool] = False
_is_uq: ClassVar[bool] = False
_is_metadata: bool
def __init_subclass__(cls) -> None:
cls._register()
@classmethod
def _register(cls):
raise NotImplementedError()
def __init__(
self, is_metadata: bool, impl: DefaultImpl, const: _C
) -> None:
raise NotImplementedError()
def compare_to_reflected(
self, other: _constraint_sig[Any]
) -> ComparisonResult:
assert self.impl is other.impl
assert self._is_metadata
assert not other._is_metadata
return self._compare_to_reflected(other)
def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
raise NotImplementedError()
@classmethod
def from_constraint(
cls, is_metadata: bool, impl: DefaultImpl, constraint: _C
) -> _constraint_sig[_C]:
# these could be cached by constraint/impl, however, if the
# constraint is modified in place, then the sig is wrong. the mysql
# impl currently does this, and if we fixed that we can't be sure
# someone else might do it too, so play it safe.
sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint)
return sig
def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
return sqla_compat._get_constraint_final_name(
self.const, context.dialect
)
@util.memoized_property
def is_named(self):
return sqla_compat._constraint_is_named(self.const, self.impl.dialect)
@util.memoized_property
def unnamed(self) -> Tuple[Any, ...]:
return self._sig
@util.memoized_property
def unnamed_no_options(self) -> Tuple[Any, ...]:
raise NotImplementedError()
@util.memoized_property
def _full_sig(self) -> Tuple[Any, ...]:
return (self.name,) + self.unnamed
def __eq__(self, other) -> bool:
return self._full_sig == other._full_sig
def __ne__(self, other) -> bool:
return self._full_sig != other._full_sig
def __hash__(self) -> int:
return hash(self._full_sig)
class _uq_constraint_sig(_constraint_sig[UniqueConstraint]):
_is_uq = True
@classmethod
def _register(cls) -> None:
_clsreg["unique_constraint"] = cls
is_unique = True
def __init__(
self,
is_metadata: bool,
impl: DefaultImpl,
const: UniqueConstraint,
) -> None:
self.impl = impl
self.const = const
self.name = sqla_compat.constraint_name_or_none(const.name)
self._sig = tuple(sorted([col.name for col in const.columns]))
self._is_metadata = is_metadata
@property
def column_names(self) -> Tuple[str, ...]:
return tuple([col.name for col in self.const.columns])
def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
assert self._is_metadata
metadata_obj = self
conn_obj = other
assert is_uq_sig(conn_obj)
return self.impl.compare_unique_constraint(
metadata_obj.const, conn_obj.const
)
class _ix_constraint_sig(_constraint_sig[Index]):
_is_index = True
name: sqla_compat._ConstraintName
@classmethod
def _register(cls) -> None:
_clsreg["index"] = cls
def __init__(
self, is_metadata: bool, impl: DefaultImpl, const: Index
) -> None:
self.impl = impl
self.const = const
self.name = const.name
self.is_unique = bool(const.unique)
self._is_metadata = is_metadata
def _compare_to_reflected(
self, other: _constraint_sig[_C]
) -> ComparisonResult:
assert self._is_metadata
metadata_obj = self
conn_obj = other
assert is_index_sig(conn_obj)
return self.impl.compare_indexes(metadata_obj.const, conn_obj.const)
@util.memoized_property
def has_expressions(self):
return sqla_compat.is_expression_index(self.const)
@util.memoized_property
def column_names(self) -> Tuple[str, ...]:
return tuple([col.name for col in self.const.columns])
@util.memoized_property
def column_names_optional(self) -> Tuple[Optional[str], ...]:
return tuple(
[getattr(col, "name", None) for col in self.const.expressions]
)
@util.memoized_property
def is_named(self):
return True
@util.memoized_property
def unnamed(self):
return (self.is_unique,) + self.column_names_optional
class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
_is_fk = True
@classmethod
def _register(cls) -> None:
_clsreg["foreign_key_constraint"] = cls
def __init__(
self,
is_metadata: bool,
impl: DefaultImpl,
const: ForeignKeyConstraint,
) -> None:
self._is_metadata = is_metadata
self.impl = impl
self.const = const
self.name = sqla_compat.constraint_name_or_none(const.name)
(
self.source_schema,
self.source_table,
self.source_columns,
self.target_schema,
self.target_table,
self.target_columns,
onupdate,
ondelete,
deferrable,
initially,
) = sqla_compat._fk_spec(const)
self._sig: Tuple[Any, ...] = (
self.source_schema,
self.source_table,
tuple(self.source_columns),
self.target_schema,
self.target_table,
tuple(self.target_columns),
) + (
(None if onupdate.lower() == "no action" else onupdate.lower())
if onupdate
else None,
(None if ondelete.lower() == "no action" else ondelete.lower())
if ondelete
else None,
# convert initially + deferrable into one three-state value
"initially_deferrable"
if initially and initially.lower() == "deferred"
else "deferrable"
if deferrable
else "not deferrable",
)
@util.memoized_property
def unnamed_no_options(self):
return (
self.source_schema,
self.source_table,
tuple(self.source_columns),
self.target_schema,
self.target_table,
tuple(self.target_columns),
)
def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]:
return sig._is_index
def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
return sig._is_uq
def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
return sig._is_fk

View File

@@ -0,0 +1,335 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import functools
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import exc
from sqlalchemy import Integer
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import DDLElement
from sqlalchemy.sql.elements import quoted_name
from ..util.sqla_compat import _columns_for_constraint # noqa
from ..util.sqla_compat import _find_columns # noqa
from ..util.sqla_compat import _fk_spec # noqa
from ..util.sqla_compat import _is_type_bound # noqa
from ..util.sqla_compat import _table_for_constraint # noqa
if TYPE_CHECKING:
from typing import Any
from sqlalchemy.sql.compiler import Compiled
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import FetchedValue
from sqlalchemy.sql.type_api import TypeEngine
from .impl import DefaultImpl
from ..util.sqla_compat import Computed
from ..util.sqla_compat import Identity
_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]
class AlterTable(DDLElement):
"""Represent an ALTER TABLE statement.
Only the string name and optional schema name of the table
is required, not a full Table object.
"""
def __init__(
self,
table_name: str,
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
self.table_name = table_name
self.schema = schema
class RenameTable(AlterTable):
def __init__(
self,
old_table_name: str,
new_table_name: Union[quoted_name, str],
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
super().__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
class AlterColumn(AlterTable):
def __init__(
self,
name: str,
column_name: str,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_nullable: Optional[bool] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_comment: Optional[str] = None,
) -> None:
super().__init__(name, schema=schema)
self.column_name = column_name
self.existing_type = (
sqltypes.to_instance(existing_type)
if existing_type is not None
else None
)
self.existing_nullable = existing_nullable
self.existing_server_default = existing_server_default
self.existing_comment = existing_comment
class ColumnNullable(AlterColumn):
def __init__(
self, name: str, column_name: str, nullable: bool, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.nullable = nullable
class ColumnType(AlterColumn):
def __init__(
self, name: str, column_name: str, type_: TypeEngine, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
class ColumnName(AlterColumn):
def __init__(
self, name: str, column_name: str, newname: str, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.newname = newname
class ColumnDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: Optional[_ServerDefault],
**kw,
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
class ComputedColumnDefault(AlterColumn):
def __init__(
self, name: str, column_name: str, default: Optional[Computed], **kw
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
class IdentityColumnDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: Optional[Identity],
impl: DefaultImpl,
**kw,
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
self.impl = impl
class AddColumn(AlterTable):
def __init__(
self,
name: str,
column: Column[Any],
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
super().__init__(name, schema=schema)
self.column = column
class DropColumn(AlterTable):
def __init__(
self, name: str, column: Column[Any], schema: Optional[str] = None
) -> None:
super().__init__(name, schema=schema)
self.column = column
class ColumnComment(AlterColumn):
def __init__(
self, name: str, column_name: str, comment: Optional[str], **kw
) -> None:
super().__init__(name, column_name, **kw)
self.comment = comment
@compiles(RenameTable) # type: ignore[misc]
def visit_rename_table(
element: RenameTable, compiler: DDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, element.schema),
)
@compiles(AddColumn) # type: ignore[misc]
def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
)
@compiles(DropColumn) # type: ignore[misc]
def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
drop_column(compiler, element.column.name, **kw),
)
@compiles(ColumnNullable) # type: ignore[misc]
def visit_column_nullable(
element: ColumnNullable, compiler: DDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"DROP NOT NULL" if element.nullable else "SET NOT NULL",
)
@compiles(ColumnType) # type: ignore[misc]
def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"TYPE %s" % format_type(compiler, element.type_),
)
@compiles(ColumnName) # type: ignore[misc]
def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
return "%s RENAME %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnDefault) # type: ignore[misc]
def visit_column_default(
element: ColumnDefault, compiler: DDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
)
@compiles(ComputedColumnDefault) # type: ignore[misc]
def visit_computed_column(
element: ComputedColumnDefault, compiler: DDLCompiler, **kw
):
raise exc.CompileError(
'Adding or removing a "computed" construct, e.g. GENERATED '
"ALWAYS AS, to or from an existing column is not supported."
)
@compiles(IdentityColumnDefault) # type: ignore[misc]
def visit_identity_column(
element: IdentityColumnDefault, compiler: DDLCompiler, **kw
):
raise exc.CompileError(
'Adding, removing or modifying an "identity" construct, '
"e.g. GENERATED AS IDENTITY, to or from an existing "
"column is not supported in this dialect."
)
def quote_dotted(
name: Union[quoted_name, str], quote: functools.partial
) -> Union[quoted_name, str]:
"""quote the elements of a dotted name"""
if isinstance(name, quoted_name):
return quote(name)
result = ".".join([quote(x) for x in name.split(".")])
return result
def format_table_name(
compiler: Compiled,
name: Union[quoted_name, str],
schema: Optional[Union[quoted_name, str]],
) -> Union[quoted_name, str]:
quote = functools.partial(compiler.preparer.quote)
if schema:
return quote_dotted(schema, quote) + "." + quote(name)
else:
return quote(name)
def format_column_name(
compiler: DDLCompiler, name: Optional[Union[quoted_name, str]]
) -> Union[quoted_name, str]:
return compiler.preparer.quote(name) # type: ignore[arg-type]
def format_server_default(
compiler: DDLCompiler,
default: Optional[_ServerDefault],
) -> str:
return compiler.get_column_default_string(
Column("x", Integer, server_default=default)
)
def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str:
return compiler.dialect.type_compiler.process(type_)
def alter_table(
compiler: DDLCompiler,
name: str,
schema: Optional[str],
) -> str:
return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
def drop_column(compiler: DDLCompiler, name: str, **kw) -> str:
return "DROP COLUMN %s" % format_column_name(compiler, name)
def alter_column(compiler: DDLCompiler, name: str) -> str:
return "ALTER COLUMN %s" % format_column_name(compiler, name)
def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str:
text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
const = " ".join(
compiler.process(constraint) for constraint in column.constraints
)
if const:
text += " " + const
return text

View File

@@ -0,0 +1,844 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import logging
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import schema
from sqlalchemy import text
from . import _autogen
from . import base
from ._autogen import _constraint_sig as _constraint_sig
from ._autogen import ComparisonResult as ComparisonResult
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from typing import TextIO
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql import Executable
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.sql.selectable import TableClause
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
from ..autogenerate.api import AutogenContext
from ..operations.batch import ApplyBatchImpl
from ..operations.batch import BatchOperationsImpl
log = logging.getLogger(__name__)
class ImplMeta(type):
def __init__(
cls,
classname: str,
bases: Tuple[Type[DefaultImpl]],
dict_: Dict[str, Any],
):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
_impls[dict_["__dialect__"]] = cls # type: ignore[assignment]
return newtype
_impls: Dict[str, Type[DefaultImpl]] = {}
class DefaultImpl(metaclass=ImplMeta):
"""Provide the entrypoint for major migration operations,
including database-specific behavioral variances.
While individual SQL/DDL constructs already provide
for database-specific implementations, variances here
allow for entirely different sequences of operations
to take place for a particular migration, such as
SQL Server's special 'IDENTITY INSERT' step for
bulk inserts.
"""
__dialect__ = "default"
transactional_ddl = False
command_terminator = ";"
type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},)
type_arg_extract: Sequence[str] = ()
# These attributes are deprecated in SQLAlchemy via #10247. They need to
# be ignored to support older version that did not use dialect kwargs.
# They only apply to Oracle and are replaced by oracle_order,
# oracle_on_null
identity_attrs_ignore: Tuple[str, ...] = ("order", "on_null")
def __init__(
self,
dialect: Dialect,
connection: Optional[Connection],
as_sql: bool,
transactional_ddl: Optional[bool],
output_buffer: Optional[TextIO],
context_opts: Dict[str, Any],
) -> None:
self.dialect = dialect
self.connection = connection
self.as_sql = as_sql
self.literal_binds = context_opts.get("literal_binds", False)
self.output_buffer = output_buffer
self.memo: dict = {}
self.context_opts = context_opts
if transactional_ddl is not None:
self.transactional_ddl = transactional_ddl
if self.literal_binds:
if not self.as_sql:
raise util.CommandError(
"Can't use literal_binds setting without as_sql mode"
)
@classmethod
def get_by_dialect(cls, dialect: Dialect) -> Type[DefaultImpl]:
return _impls[dialect.name]
def static_output(self, text: str) -> None:
assert self.output_buffer is not None
self.output_buffer.write(text + "\n\n")
self.output_buffer.flush()
def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
Normally, only returns True on SQLite when operations other
than add_column are present.
"""
return False
def prep_table_for_batch(
self, batch_impl: ApplyBatchImpl, table: Table
) -> None:
"""perform any operations needed on a table before a new
one is created to replace it in batch mode.
the PG dialect uses this to drop constraints on the table
before the new one uses those same names.
"""
@property
def bind(self) -> Optional[Connection]:
return self.connection
def _exec(
self,
construct: Union[Executable, str],
execution_options: Optional[dict[str, Any]] = None,
multiparams: Sequence[dict] = (),
params: Dict[str, Any] = util.immutabledict(),
) -> Optional[CursorResult]:
if isinstance(construct, str):
construct = text(construct)
if self.as_sql:
if multiparams or params:
# TODO: coverage
raise Exception("Execution arguments not allowed with as_sql")
compile_kw: dict[str, Any]
if self.literal_binds and not isinstance(
construct, schema.DDLElement
):
compile_kw = dict(compile_kwargs={"literal_binds": True})
else:
compile_kw = {}
if TYPE_CHECKING:
assert isinstance(construct, ClauseElement)
compiled = construct.compile(dialect=self.dialect, **compile_kw)
self.static_output(
str(compiled).replace("\t", " ").strip()
+ self.command_terminator
)
return None
else:
conn = self.connection
assert conn is not None
if execution_options:
conn = conn.execution_options(**execution_options)
if params:
assert isinstance(multiparams, tuple)
multiparams += (params,)
return conn.execute(construct, multiparams)
def execute(
self,
sql: Union[Executable, str],
execution_options: Optional[dict[str, Any]] = None,
) -> None:
self._exec(sql, execution_options)
def alter_column(
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
existing_comment: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
**kw: Any,
) -> None:
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
"autoincrement and existing_autoincrement "
"only make sense for MySQL",
stacklevel=3,
)
if nullable is not None:
self._exec(
base.ColumnNullable(
table_name,
column_name,
nullable,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
if server_default is not False:
kw = {}
cls_: Type[
Union[
base.ComputedColumnDefault,
base.IdentityColumnDefault,
base.ColumnDefault,
]
]
if sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
cls_ = base.ComputedColumnDefault
elif sqla_compat._server_default_is_identity(
server_default, existing_server_default
):
cls_ = base.IdentityColumnDefault
kw["impl"] = self
else:
cls_ = base.ColumnDefault
self._exec(
cls_(
table_name,
column_name,
server_default, # type:ignore[arg-type]
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
**kw,
)
)
if type_ is not None:
self._exec(
base.ColumnType(
table_name,
column_name,
type_,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
if comment is not False:
self._exec(
base.ColumnComment(
table_name,
column_name,
comment,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
# do the new name last ;)
if name is not None:
self._exec(
base.ColumnName(
table_name,
column_name,
name,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
)
)
def add_column(
self,
table_name: str,
column: Column[Any],
schema: Optional[Union[str, quoted_name]] = None,
) -> None:
self._exec(base.AddColumn(table_name, column, schema=schema))
def drop_column(
self,
table_name: str,
column: Column[Any],
schema: Optional[str] = None,
**kw,
) -> None:
self._exec(base.DropColumn(table_name, column, schema=schema))
def add_constraint(self, const: Any) -> None:
if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
def drop_constraint(self, const: Constraint) -> None:
self._exec(schema.DropConstraint(const))
def rename_table(
self,
old_table_name: str,
new_table_name: Union[str, quoted_name],
schema: Optional[Union[str, quoted_name]] = None,
) -> None:
self._exec(
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
def create_table(self, table: Table) -> None:
table.dispatch.before_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.CreateTable(table))
table.dispatch.after_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
for index in table.indexes:
self._exec(schema.CreateIndex(index))
with_comment = (
self.dialect.supports_comments and not self.dialect.inline_comments
)
comment = table.comment
if comment and with_comment:
self.create_table_comment(table)
for column in table.columns:
comment = column.comment
if comment and with_comment:
self.create_column_comment(column)
def drop_table(self, table: Table) -> None:
table.dispatch.before_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.DropTable(table))
table.dispatch.after_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
def create_index(self, index: Index, **kw: Any) -> None:
self._exec(schema.CreateIndex(index, **kw))
def create_table_comment(self, table: Table) -> None:
self._exec(schema.SetTableComment(table))
def drop_table_comment(self, table: Table) -> None:
self._exec(schema.DropTableComment(table))
def create_column_comment(self, column: ColumnElement[Any]) -> None:
self._exec(schema.SetColumnComment(column))
def drop_index(self, index: Index, **kw: Any) -> None:
self._exec(schema.DropIndex(index, **kw))
def bulk_insert(
self,
table: Union[TableClause, Table],
rows: List[dict],
multiinsert: bool = True,
) -> None:
if not isinstance(rows, list):
raise TypeError("List expected")
elif rows and not isinstance(rows[0], dict):
raise TypeError("List of dictionaries expected")
if self.as_sql:
for row in rows:
self._exec(
sqla_compat._insert_inline(table).values(
**{
k: sqla_compat._literal_bindparam(
k, v, type_=table.c[k].type
)
if not isinstance(
v, sqla_compat._literal_bindparam
)
else v
for k, v in row.items()
}
)
)
else:
if rows:
if multiinsert:
self._exec(
sqla_compat._insert_inline(table), multiparams=rows
)
else:
for row in rows:
self._exec(
sqla_compat._insert_inline(table).values(**row)
)
def _tokenize_column_type(self, column: Column) -> Params:
definition: str
definition = self.dialect.type_compiler.process(column.type).lower()
# tokenize the SQLAlchemy-generated version of a type, so that
# the two can be compared.
#
# examples:
# NUMERIC(10, 5)
# TIMESTAMP WITH TIMEZONE
# INTEGER UNSIGNED
# INTEGER (10) UNSIGNED
# INTEGER(10) UNSIGNED
# varchar character set utf8
#
tokens: List[str] = re.findall(r"[\w\-_]+|\(.+?\)", definition)
term_tokens: List[str] = []
paren_term = None
for token in tokens:
if re.match(r"^\(.*\)$", token):
paren_term = token
else:
term_tokens.append(token)
params = Params(term_tokens[0], term_tokens[1:], [], {})
if paren_term:
term: str
for term in re.findall("[^(),]+", paren_term):
if "=" in term:
key, val = term.split("=")
params.kwargs[key.strip()] = val.strip()
else:
params.args.append(term.strip())
return params
def _column_types_match(
self, inspector_params: Params, metadata_params: Params
) -> bool:
if inspector_params.token0 == metadata_params.token0:
return True
synonyms = [{t.lower() for t in batch} for batch in self.type_synonyms]
inspector_all_terms = " ".join(
[inspector_params.token0] + inspector_params.tokens
)
metadata_all_terms = " ".join(
[metadata_params.token0] + metadata_params.tokens
)
for batch in synonyms:
if {inspector_all_terms, metadata_all_terms}.issubset(batch) or {
inspector_params.token0,
metadata_params.token0,
}.issubset(batch):
return True
return False
def _column_args_match(
self, inspected_params: Params, meta_params: Params
) -> bool:
"""We want to compare column parameters. However, we only want
to compare parameters that are set. If they both have `collation`,
we want to make sure they are the same. However, if only one
specifies it, dont flag it for being less specific
"""
if (
len(meta_params.tokens) == len(inspected_params.tokens)
and meta_params.tokens != inspected_params.tokens
):
return False
if (
len(meta_params.args) == len(inspected_params.args)
and meta_params.args != inspected_params.args
):
return False
insp = " ".join(inspected_params.tokens).lower()
meta = " ".join(meta_params.tokens).lower()
for reg in self.type_arg_extract:
mi = re.search(reg, insp)
mm = re.search(reg, meta)
if mi and mm and mi.group(1) != mm.group(1):
return False
return True
def compare_type(
self, inspector_column: Column[Any], metadata_column: Column
) -> bool:
"""Returns True if there ARE differences between the types of the two
columns. Takes impl.type_synonyms into account between retrospected
and metadata types
"""
inspector_params = self._tokenize_column_type(inspector_column)
metadata_params = self._tokenize_column_type(metadata_column)
if not self._column_types_match(inspector_params, metadata_params):
return True
if not self._column_args_match(inspector_params, metadata_params):
return True
return False
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
return rendered_inspector_default != rendered_metadata_default
def correct_for_autogen_constraints(
self,
conn_uniques: Set[UniqueConstraint],
conn_indexes: Set[Index],
metadata_unique_constraints: Set[UniqueConstraint],
metadata_indexes: Set[Index],
) -> None:
pass
def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
if existing.type._type_affinity is not new_type._type_affinity:
existing_transfer["expr"] = cast(
existing_transfer["expr"], new_type
)
def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw: Any
) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
"""
compile_kw = {"literal_binds": True, "include_table": False}
return str(
expr.compile(dialect=self.dialect, compile_kwargs=compile_kw)
)
def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable:
return self.autogen_column_reflect
def correct_for_autogen_foreignkeys(
self,
conn_fks: Set[ForeignKeyConstraint],
metadata_fks: Set[ForeignKeyConstraint],
) -> None:
pass
def autogen_column_reflect(self, inspector, table, column_info):
"""A hook that is attached to the 'column_reflect' event for when
a Table is reflected from the database during the autogenerate
process.
Dialects can elect to modify the information gathered here.
"""
def start_migrations(self) -> None:
"""A hook called when :meth:`.EnvironmentContext.run_migrations`
is called.
Implementations can set up per-migration-run state here.
"""
def emit_begin(self) -> None:
"""Emit the string ``BEGIN``, or the backend-specific
equivalent, on the current connection context.
This is used in offline mode and typically
via :meth:`.EnvironmentContext.begin_transaction`.
"""
self.static_output("BEGIN" + self.command_terminator)
def emit_commit(self) -> None:
"""Emit the string ``COMMIT``, or the backend-specific
equivalent, on the current connection context.
This is used in offline mode and typically
via :meth:`.EnvironmentContext.begin_transaction`.
"""
self.static_output("COMMIT" + self.command_terminator)
def render_type(
self, type_obj: TypeEngine, autogen_context: AutogenContext
) -> Union[str, Literal[False]]:
return False
def _compare_identity_default(self, metadata_identity, inspector_identity):
# ignored contains the attributes that were not considered
# because assumed to their default values in the db.
diff, ignored = _compare_identity_options(
metadata_identity,
inspector_identity,
sqla_compat.Identity(),
skip={"always"},
)
meta_always = getattr(metadata_identity, "always", None)
inspector_always = getattr(inspector_identity, "always", None)
# None and False are the same in this comparison
if bool(meta_always) != bool(inspector_always):
diff.add("always")
diff.difference_update(self.identity_attrs_ignore)
# returns 3 values:
return (
# different identity attributes
diff,
# ignored identity attributes
ignored,
# if the two identity should be considered different
bool(diff) or bool(metadata_identity) != bool(inspector_identity),
)
def _compare_index_unique(
self, metadata_index: Index, reflected_index: Index
) -> Optional[str]:
conn_unique = bool(reflected_index.unique)
meta_unique = bool(metadata_index.unique)
if conn_unique != meta_unique:
return f"unique={conn_unique} to unique={meta_unique}"
else:
return None
def _create_metadata_constraint_sig(
self, constraint: _autogen._C, **opts: Any
) -> _constraint_sig[_autogen._C]:
return _constraint_sig.from_constraint(True, self, constraint, **opts)
def _create_reflected_constraint_sig(
self, constraint: _autogen._C, **opts: Any
) -> _constraint_sig[_autogen._C]:
return _constraint_sig.from_constraint(False, self, constraint, **opts)
def compare_indexes(
self,
metadata_index: Index,
reflected_index: Index,
) -> ComparisonResult:
"""Compare two indexes by comparing the signature generated by
``create_index_sig``.
This method returns a ``ComparisonResult``.
"""
msg: List[str] = []
unique_msg = self._compare_index_unique(
metadata_index, reflected_index
)
if unique_msg:
msg.append(unique_msg)
m_sig = self._create_metadata_constraint_sig(metadata_index)
r_sig = self._create_reflected_constraint_sig(reflected_index)
assert _autogen.is_index_sig(m_sig)
assert _autogen.is_index_sig(r_sig)
# The assumption is that the index have no expression
for sig in m_sig, r_sig:
if sig.has_expressions:
log.warning(
"Generating approximate signature for index %s. "
"The dialect "
"implementation should either skip expression indexes "
"or provide a custom implementation.",
sig.const,
)
if m_sig.column_names != r_sig.column_names:
msg.append(
f"expression {r_sig.column_names} to {m_sig.column_names}"
)
if msg:
return ComparisonResult.Different(msg)
else:
return ComparisonResult.Equal()
def compare_unique_constraint(
self,
metadata_constraint: UniqueConstraint,
reflected_constraint: UniqueConstraint,
) -> ComparisonResult:
"""Compare two unique constraints by comparing the two signatures.
The arguments are two tuples that contain the unique constraint and
the signatures generated by ``create_unique_constraint_sig``.
This method returns a ``ComparisonResult``.
"""
metadata_tup = self._create_metadata_constraint_sig(
metadata_constraint
)
reflected_tup = self._create_reflected_constraint_sig(
reflected_constraint
)
meta_sig = metadata_tup.unnamed
conn_sig = reflected_tup.unnamed
if conn_sig != meta_sig:
return ComparisonResult.Different(
f"expression {conn_sig} to {meta_sig}"
)
else:
return ComparisonResult.Equal()
def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
conn_indexes_by_name = {c.name: c for c in conn_indexes}
for idx in list(metadata_indexes):
if idx.name in conn_indexes_by_name:
continue
iex = sqla_compat.is_expression_index(idx)
if iex:
util.warn(
"autogenerate skipping metadata-specified "
"expression-based index "
f"{idx.name!r}; dialect {self.__dialect__!r} under "
f"SQLAlchemy {sqla_compat.sqlalchemy_version} can't "
"reflect these indexes so they can't be compared"
)
metadata_indexes.discard(idx)
def adjust_reflected_dialect_options(
self, reflected_object: Dict[str, Any], kind: str
) -> Dict[str, Any]:
return reflected_object.get("dialect_options", {})
class Params(NamedTuple):
token0: str
tokens: List[str]
args: List[str]
kwargs: Dict[str, str]
def _compare_identity_options(
metadata_io: Union[schema.Identity, schema.Sequence, None],
inspector_io: Union[schema.Identity, schema.Sequence, None],
default_io: Union[schema.Identity, schema.Sequence],
skip: Set[str],
):
# this can be used for identity or sequence compare.
# default_io is an instance of IdentityOption with all attributes to the
# default value.
meta_d = sqla_compat._get_identity_options_dict(metadata_io)
insp_d = sqla_compat._get_identity_options_dict(inspector_io)
diff = set()
ignored_attr = set()
def check_dicts(
meta_dict: Mapping[str, Any],
insp_dict: Mapping[str, Any],
default_dict: Mapping[str, Any],
attrs: Iterable[str],
):
for attr in set(attrs).difference(skip):
meta_value = meta_dict.get(attr)
insp_value = insp_dict.get(attr)
if insp_value != meta_value:
default_value = default_dict.get(attr)
if meta_value == default_value:
ignored_attr.add(attr)
else:
diff.add(attr)
check_dicts(
meta_d,
insp_d,
sqla_compat._get_identity_options_dict(default_io),
set(meta_d).union(insp_d),
)
if sqla_compat.identity_has_dialect_kwargs:
# use only the dialect kwargs in inspector_io since metadata_io
# can have options for many backends
check_dicts(
getattr(metadata_io, "dialect_kwargs", {}),
getattr(inspector_io, "dialect_kwargs", {}),
default_io.dialect_kwargs, # type: ignore[union-attr]
getattr(inspector_io, "dialect_kwargs", {}),
)
return diff, ignored_attr

View File

@@ -0,0 +1,419 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import types as sqltypes
from sqlalchemy.schema import Column
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.elements import ClauseElement
from .base import AddColumn
from .base import alter_column
from .base import alter_table
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .base import format_table_name
from .base import format_type
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.mssql.base import MSDDLCompiler
from sqlalchemy.dialects.mssql.base import MSSQLCompiler
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import TableClause
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
class MSSQLImpl(DefaultImpl):
__dialect__ = "mssql"
transactional_ddl = True
batch_separator = "GO"
type_synonyms = DefaultImpl.type_synonyms + ({"VARCHAR", "NVARCHAR"},)
identity_attrs_ignore = DefaultImpl.identity_attrs_ignore + (
"minvalue",
"maxvalue",
"nominvalue",
"nomaxvalue",
"cycle",
"cache",
)
def __init__(self, *arg, **kw) -> None:
super().__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"mssql_batch_separator", self.batch_separator
)
def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
result = super()._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
def emit_begin(self) -> None:
self.static_output("BEGIN TRANSACTION" + self.command_terminator)
def emit_commit(self) -> None:
super().emit_commit()
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
**kw: Any,
) -> None:
if nullable is not None:
if type_ is not None:
# the NULL/NOT NULL alter will handle
# the type alteration
existing_type = type_
type_ = None
elif existing_type is None:
raise util.CommandError(
"MS-SQL ALTER COLUMN operations "
"with NULL or NOT NULL require the "
"existing_type or a new type_ be passed."
)
elif existing_nullable is not None and type_ is not None:
nullable = existing_nullable
# the NULL/NOT NULL alter will handle
# the type alteration
existing_type = type_
type_ = None
elif type_ is not None:
util.warn(
"MS-SQL ALTER COLUMN operations that specify type_= "
"should also specify a nullable= or "
"existing_nullable= argument to avoid implicit conversion "
"of NOT NULL columns to NULL."
)
used_default = False
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
used_default = True
kw["server_default"] = server_default
kw["existing_server_default"] = existing_server_default
super().alter_column(
table_name,
column_name,
nullable=nullable,
type_=type_,
schema=schema,
existing_type=existing_type,
existing_nullable=existing_nullable,
**kw,
)
if server_default is not False and used_default is False:
if existing_server_default is not False or server_default is None:
self._exec(
_ExecDropConstraint(
table_name,
column_name,
"sys.default_constraints",
schema,
)
)
if server_default is not None:
super().alter_column(
table_name,
column_name,
schema=schema,
server_default=server_default,
)
if name is not None:
super().alter_column(
table_name, column_name, schema=schema, name=name
)
def create_index(self, index: Index, **kw: Any) -> None:
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
mssql_include = index.kwargs.get("mssql_include", None) or ()
assert index.table is not None
for col in mssql_include:
if col not in index.table.c:
index.table.append_column(Column(col, sqltypes.NullType))
self._exec(CreateIndex(index, **kw))
def bulk_insert( # type:ignore[override]
self, table: Union[TableClause, Table], rows: List[dict], **kw: Any
) -> None:
if self.as_sql:
self._exec(
"SET IDENTITY_INSERT %s ON"
% self.dialect.identifier_preparer.format_table(table)
)
super().bulk_insert(table, rows, **kw)
self._exec(
"SET IDENTITY_INSERT %s OFF"
% self.dialect.identifier_preparer.format_table(table)
)
else:
super().bulk_insert(table, rows, **kw)
def drop_column(
self,
table_name: str,
column: Column[Any],
schema: Optional[str] = None,
**kw,
) -> None:
drop_default = kw.pop("mssql_drop_default", False)
if drop_default:
self._exec(
_ExecDropConstraint(
table_name, column, "sys.default_constraints", schema
)
)
drop_check = kw.pop("mssql_drop_check", False)
if drop_check:
self._exec(
_ExecDropConstraint(
table_name, column, "sys.check_constraints", schema
)
)
drop_fks = kw.pop("mssql_drop_foreign_key", False)
if drop_fks:
self._exec(_ExecDropFKConstraint(table_name, column, schema))
super().drop_column(table_name, column, schema=schema, **kw)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"[\(\) \"\']", "", rendered_metadata_default
)
if rendered_inspector_default is not None:
# SQL Server collapses whitespace and adds arbitrary parenthesis
# within expressions. our only option is collapse all of it
rendered_inspector_default = re.sub(
r"[\(\) \"\']", "", rendered_inspector_default
)
return rendered_inspector_default != rendered_metadata_default
def _compare_identity_default(self, metadata_identity, inspector_identity):
diff, ignored, is_alter = super()._compare_identity_default(
metadata_identity, inspector_identity
)
if (
metadata_identity is None
and inspector_identity is not None
and not diff
and inspector_identity.column is not None
and inspector_identity.column.primary_key
):
# mssql reflect primary keys with autoincrement as identity
# columns. if no different attributes are present ignore them
is_alter = False
return diff, ignored, is_alter
def adjust_reflected_dialect_options(
self, reflected_object: Dict[str, Any], kind: str
) -> Dict[str, Any]:
options: Dict[str, Any]
options = reflected_object.get("dialect_options", {}).copy()
if not options.get("mssql_include"):
options.pop("mssql_include", None)
if not options.get("mssql_clustered"):
options.pop("mssql_clustered", None)
return options
class _ExecDropConstraint(Executable, ClauseElement):
inherit_cache = False
def __init__(
self,
tname: str,
colname: Union[Column[Any], str],
type_: str,
schema: Optional[str],
) -> None:
self.tname = tname
self.colname = colname
self.type_ = type_
self.schema = schema
class _ExecDropFKConstraint(Executable, ClauseElement):
inherit_cache = False
def __init__(
self, tname: str, colname: Column[Any], schema: Optional[str]
) -> None:
self.tname = tname
self.colname = colname
self.schema = schema
@compiles(_ExecDropConstraint, "mssql")
def _exec_drop_col_constraint(
element: _ExecDropConstraint, compiler: MSSQLCompiler, **kw
) -> str:
schema, tname, colname, type_ = (
element.schema,
element.tname,
element.colname,
element.type_,
)
# from http://www.mssqltips.com/sqlservertip/1425/\
# working-with-default-constraints-in-sql-server/
return """declare @const_name varchar(256)
select @const_name = QUOTENAME([name]) from %(type)s
where parent_object_id = object_id('%(schema_dot)s%(tname)s')
and col_name(parent_object_id, parent_column_id) = '%(colname)s'
exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
"type": type_,
"tname": tname,
"colname": colname,
"tname_quoted": format_table_name(compiler, tname, schema),
"schema_dot": schema + "." if schema else "",
}
@compiles(_ExecDropFKConstraint, "mssql")
def _exec_drop_col_fk_constraint(
element: _ExecDropFKConstraint, compiler: MSSQLCompiler, **kw
) -> str:
schema, tname, colname = element.schema, element.tname, element.colname
return """declare @const_name varchar(256)
select @const_name = QUOTENAME([name]) from
sys.foreign_keys fk join sys.foreign_key_columns fkc
on fk.object_id=fkc.constraint_object_id
where fkc.parent_object_id = object_id('%(schema_dot)s%(tname)s')
and col_name(fkc.parent_object_id, fkc.parent_column_id) = '%(colname)s'
exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
"tname": tname,
"colname": colname,
"tname_quoted": format_table_name(compiler, tname, schema),
"schema_dot": schema + "." if schema else "",
}
@compiles(AddColumn, "mssql")
def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
mssql_add_column(compiler, element.column, **kw),
)
def mssql_add_column(
compiler: MSDDLCompiler, column: Column[Any], **kw
) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(ColumnNullable, "mssql")
def visit_column_nullable(
element: ColumnNullable, compiler: MSDDLCompiler, **kw
) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
format_type(compiler, element.existing_type), # type: ignore[arg-type]
"NULL" if element.nullable else "NOT NULL",
)
@compiles(ColumnDefault, "mssql")
def visit_column_default(
element: ColumnDefault, compiler: MSDDLCompiler, **kw
) -> str:
# TODO: there can also be a named constraint
# with ADD CONSTRAINT here
return "%s ADD DEFAULT %s FOR %s" % (
alter_table(compiler, element.table_name, element.schema),
format_server_default(compiler, element.default),
format_column_name(compiler, element.column_name),
)
@compiles(ColumnName, "mssql")
def visit_rename_column(
element: ColumnName, compiler: MSDDLCompiler, **kw
) -> str:
return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
format_table_name(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnType, "mssql")
def visit_column_type(
element: ColumnType, compiler: MSDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
format_type(compiler, element.type_),
)
@compiles(RenameTable, "mssql")
def visit_rename_table(
element: RenameTable, compiler: MSDDLCompiler, **kw
) -> str:
return "EXEC sp_rename '%s', %s" % (
format_table_name(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)

View File

@@ -0,0 +1,474 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import schema
from sqlalchemy import types as sqltypes
from .base import alter_table
from .base import AlterColumn
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
from sqlalchemy.sql.ddl import DropConstraint
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
class MySQLImpl(DefaultImpl):
__dialect__ = "mysql"
transactional_ddl = False
type_synonyms = DefaultImpl.type_synonyms + (
{"BOOL", "TINYINT"},
{"JSON", "LONGTEXT"},
)
type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
autoincrement: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
existing_comment: Optional[str] = None,
**kw: Any,
) -> None:
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
# modifying computed or identity columns is not supported
# the default will raise
super().alter_column(
table_name,
column_name,
nullable=nullable,
type_=type_,
schema=schema,
existing_type=existing_type,
existing_nullable=existing_nullable,
server_default=server_default,
existing_server_default=existing_server_default,
**kw,
)
if name is not None or self._is_mysql_allowed_functional_default(
type_ if type_ is not None else existing_type, server_default
):
self._exec(
MySQLChangeColumn(
table_name,
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
)
)
elif (
nullable is not None
or type_ is not None
or autoincrement is not None
or comment is not False
):
self._exec(
MySQLModifyColumn(
table_name,
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
)
)
elif server_default is not False:
self._exec(
MySQLAlterDefault(
table_name, column_name, server_default, schema=schema
)
)
def drop_constraint(
self,
const: Constraint,
) -> None:
if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
return
super().drop_constraint(const)
def _is_mysql_allowed_functional_default(
self,
type_: Optional[TypeEngine],
server_default: Union[_ServerDefault, Literal[False]],
) -> bool:
return (
type_ is not None
and type_._type_affinity is sqltypes.DateTime
and server_default is not None
)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
# partially a workaround for SQLAlchemy issue #3023; if the
# column were created without "NOT NULL", MySQL may have added
# an implicit default of '0' which we need to skip
# TODO: this is not really covered anymore ?
if (
metadata_column.type._type_affinity is sqltypes.Integer
and inspector_column.primary_key
and not inspector_column.autoincrement
and not rendered_metadata_default
and rendered_inspector_default == "'0'"
):
return False
elif (
rendered_inspector_default
and inspector_column.type._type_affinity is sqltypes.Integer
):
rendered_inspector_default = (
re.sub(r"^'|'$", "", rendered_inspector_default)
if rendered_inspector_default is not None
else None
)
return rendered_inspector_default != rendered_metadata_default
elif (
rendered_metadata_default
and metadata_column.type._type_affinity is sqltypes.String
):
metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default)
return rendered_inspector_default != f"'{metadata_default}'"
elif rendered_inspector_default and rendered_metadata_default:
# adjust for "function()" vs. "FUNCTION" as can occur particularly
# for the CURRENT_TIMESTAMP function on newer MariaDB versions
# SQLAlchemy MySQL dialect bundles ON UPDATE into the server
# default; adjust for this possibly being present.
onupdate_ins = re.match(
r"(.*) (on update.*?)(?:\(\))?$",
rendered_inspector_default.lower(),
)
onupdate_met = re.match(
r"(.*) (on update.*?)(?:\(\))?$",
rendered_metadata_default.lower(),
)
if onupdate_ins:
if not onupdate_met:
return True
elif onupdate_ins.group(2) != onupdate_met.group(2):
return True
rendered_inspector_default = onupdate_ins.group(1)
rendered_metadata_default = onupdate_met.group(1)
return re.sub(
r"(.*?)(?:\(\))?$", r"\1", rendered_inspector_default.lower()
) != re.sub(
r"(.*?)(?:\(\))?$", r"\1", rendered_metadata_default.lower()
)
else:
return rendered_inspector_default != rendered_metadata_default
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
# TODO: if SQLA 1.0, make use of "duplicates_index"
# metadata
removed = set()
for idx in list(conn_indexes):
if idx.unique:
continue
# MySQL puts implicit indexes on FK columns, even if
# composite and even if MyISAM, so can't check this too easily.
# the name of the index may be the column name or it may
# be the name of the FK constraint.
for col in idx.columns:
if idx.name == col.name:
conn_indexes.remove(idx)
removed.add(idx.name)
break
for fk in col.foreign_keys:
if fk.name == idx.name:
conn_indexes.remove(idx)
removed.add(idx.name)
break
if idx.name in removed:
break
# then remove indexes from the "metadata_indexes"
# that we've removed from reflected, otherwise they come out
# as adds (see #202)
for idx in list(metadata_indexes):
if idx.name in removed:
metadata_indexes.remove(idx)
def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
conn_fk_by_sig = {
self._create_reflected_constraint_sig(fk).unnamed_no_options: fk
for fk in conn_fks
}
metadata_fk_by_sig = {
self._create_metadata_constraint_sig(fk).unnamed_no_options: fk
for fk in metadata_fks
}
for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
mdfk = metadata_fk_by_sig[sig]
cnfk = conn_fk_by_sig[sig]
# MySQL considers RESTRICT to be the default and doesn't
# report on it. if the model has explicit RESTRICT and
# the conn FK has None, set it to RESTRICT
if (
mdfk.ondelete is not None
and mdfk.ondelete.lower() == "restrict"
and cnfk.ondelete is None
):
cnfk.ondelete = "RESTRICT"
if (
mdfk.onupdate is not None
and mdfk.onupdate.lower() == "restrict"
and cnfk.onupdate is None
):
cnfk.onupdate = "RESTRICT"
class MariaDBImpl(MySQLImpl):
__dialect__ = "mariadb"
class MySQLAlterDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: _ServerDefault,
schema: Optional[str] = None,
) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.default = default
class MySQLChangeColumn(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
schema: Optional[str] = None,
newname: Optional[str] = None,
type_: Optional[TypeEngine] = None,
nullable: Optional[bool] = None,
default: Optional[Union[_ServerDefault, Literal[False]]] = False,
autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.nullable = nullable
self.newname = newname
self.default = default
self.autoincrement = autoincrement
self.comment = comment
if type_ is None:
raise util.CommandError(
"All MySQL CHANGE/MODIFY COLUMN operations "
"require the existing type."
)
self.type_ = sqltypes.to_instance(type_)
class MySQLModifyColumn(MySQLChangeColumn):
pass
@compiles(ColumnNullable, "mysql", "mariadb")
@compiles(ColumnName, "mysql", "mariadb")
@compiles(ColumnDefault, "mysql", "mariadb")
@compiles(ColumnType, "mysql", "mariadb")
def _mysql_doesnt_support_individual(element, compiler, **kw):
raise NotImplementedError(
"Individual alter column constructs not supported by MySQL"
)
@compiles(MySQLAlterDefault, "mysql", "mariadb")
def _mysql_alter_default(
element: MySQLAlterDefault, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s ALTER COLUMN %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
)
@compiles(MySQLModifyColumn, "mysql", "mariadb")
def _mysql_modify_column(
element: MySQLModifyColumn, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s MODIFY %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
_mysql_colspec(
compiler,
nullable=element.nullable,
server_default=element.default,
type_=element.type_,
autoincrement=element.autoincrement,
comment=element.comment,
),
)
@compiles(MySQLChangeColumn, "mysql", "mariadb")
def _mysql_change_column(
element: MySQLChangeColumn, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s CHANGE %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
_mysql_colspec(
compiler,
nullable=element.nullable,
server_default=element.default,
type_=element.type_,
autoincrement=element.autoincrement,
comment=element.comment,
),
)
def _mysql_colspec(
compiler: MySQLDDLCompiler,
nullable: Optional[bool],
server_default: Optional[Union[_ServerDefault, Literal[False]]],
type_: TypeEngine,
autoincrement: Optional[bool],
comment: Optional[Union[str, Literal[False]]],
) -> str:
spec = "%s %s" % (
compiler.dialect.type_compiler.process(type_),
"NULL" if nullable else "NOT NULL",
)
if autoincrement:
spec += " AUTO_INCREMENT"
if server_default is not False and server_default is not None:
spec += " DEFAULT %s" % format_server_default(compiler, server_default)
if comment:
spec += " COMMENT %s" % compiler.sql_compiler.render_literal_value(
comment, sqltypes.String()
)
return spec
@compiles(schema.DropConstraint, "mysql", "mariadb")
def _mysql_drop_constraint(
element: DropConstraint, compiler: MySQLDDLCompiler, **kw
) -> str:
"""Redefine SQLAlchemy's drop constraint to
raise errors for invalid constraint type."""
constraint = element.element
if isinstance(
constraint,
(
schema.ForeignKeyConstraint,
schema.PrimaryKeyConstraint,
schema.UniqueConstraint,
),
):
assert not kw
return compiler.visit_drop_constraint(element)
elif isinstance(constraint, schema.CheckConstraint):
# note that SQLAlchemy as of 1.2 does not yet support
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
# here.
if _is_mariadb(compiler.dialect):
return "ALTER TABLE %s DROP CONSTRAINT %s" % (
compiler.preparer.format_table(constraint.table),
compiler.preparer.format_constraint(constraint),
)
else:
return "ALTER TABLE %s DROP CHECK %s" % (
compiler.preparer.format_table(constraint.table),
compiler.preparer.format_constraint(constraint),
)
else:
raise NotImplementedError(
"No generic 'DROP CONSTRAINT' in MySQL - "
"please specify constraint type"
)

View File

@@ -0,0 +1,200 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from sqlalchemy.sql import sqltypes
from .base import AddColumn
from .base import alter_table
from .base import ColumnComment
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.sql.schema import Column
class OracleImpl(DefaultImpl):
__dialect__ = "oracle"
transactional_ddl = False
batch_separator = "/"
command_terminator = ""
type_synonyms = DefaultImpl.type_synonyms + (
{"VARCHAR", "VARCHAR2"},
{"BIGINT", "INTEGER", "SMALLINT", "DECIMAL", "NUMERIC", "NUMBER"},
{"DOUBLE", "FLOAT", "DOUBLE_PRECISION"},
)
identity_attrs_ignore = ()
def __init__(self, *arg, **kw) -> None:
super().__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"oracle_batch_separator", self.batch_separator
)
def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
result = super()._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_metadata_default
)
rendered_metadata_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default
)
if rendered_inspector_default is not None:
rendered_inspector_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_inspector_default
)
rendered_inspector_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default
)
rendered_inspector_default = rendered_inspector_default.strip()
return rendered_inspector_default != rendered_metadata_default
def emit_begin(self) -> None:
self._exec("SET TRANSACTION READ WRITE")
def emit_commit(self) -> None:
self._exec("COMMIT")
@compiles(AddColumn, "oracle")
def visit_add_column(
element: AddColumn, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
)
@compiles(ColumnNullable, "oracle")
def visit_column_nullable(
element: ColumnNullable, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"NULL" if element.nullable else "NOT NULL",
)
@compiles(ColumnType, "oracle")
def visit_column_type(
element: ColumnType, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"%s" % format_type(compiler, element.type_),
)
@compiles(ColumnName, "oracle")
def visit_column_name(
element: ColumnName, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s RENAME COLUMN %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnDefault, "oracle")
def visit_column_default(
element: ColumnDefault, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DEFAULT NULL",
)
@compiles(ColumnComment, "oracle")
def visit_column_comment(
element: ColumnComment, compiler: OracleDDLCompiler, **kw
) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = compiler.sql_compiler.render_literal_value(
(element.comment if element.comment is not None else ""),
sqltypes.String(),
)
return ddl.format(
table_name=element.table_name,
column_name=element.column_name,
comment=comment,
)
@compiles(RenameTable, "oracle")
def visit_rename_table(
element: RenameTable, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
def alter_column(compiler: OracleDDLCompiler, name: str) -> str:
return "MODIFY %s" % format_column_name(compiler, name)
def add_column(compiler: OracleDDLCompiler, column: Column[Any], **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(IdentityColumnDefault, "oracle")
def visit_identity_column(
element: IdentityColumnDefault, compiler: OracleDDLCompiler, **kw
):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
)
if element.default is None:
# drop identity
text += "DROP IDENTITY"
return text
else:
text += compiler.visit_identity_column(element.default)
return text

View File

@@ -0,0 +1,848 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import logging
import re
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import Column
from sqlalchemy import literal_column
from sqlalchemy import Numeric
from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import FunctionElement
from sqlalchemy.types import NULLTYPE
from .base import alter_column
from .base import alter_table
from .base import AlterColumn
from .base import ColumnComment
from .base import format_column_name
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import ComparisonResult
from .impl import DefaultImpl
from .. import util
from ..autogenerate import render
from ..operations import ops
from ..operations import schemaobj
from ..operations.base import BatchOperations
from ..operations.base import Operations
from ..util import sqla_compat
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy import Index
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects.postgresql.array import ARRAY
from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
from sqlalchemy.dialects.postgresql.hstore import HSTORE
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.dialects.postgresql.json import JSONB
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
from ..autogenerate.api import AutogenContext
from ..autogenerate.render import _f_name
from ..runtime.migration import MigrationContext
log = logging.getLogger(__name__)
class PostgresqlImpl(DefaultImpl):
__dialect__ = "postgresql"
transactional_ddl = True
type_synonyms = DefaultImpl.type_synonyms + (
{"FLOAT", "DOUBLE PRECISION"},
)
def create_index(self, index: Index, **kw: Any) -> None:
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
postgresql_include = index.kwargs.get("postgresql_include", None) or ()
for col in postgresql_include:
if col not in index.table.c: # type: ignore[union-attr]
index.table.append_column( # type: ignore[union-attr]
Column(col, sqltypes.NullType)
)
self._exec(CreateIndex(index, **kw))
def prep_table_for_batch(self, batch_impl, table):
for constraint in table.constraints:
if (
constraint.name is not None
and constraint.name in batch_impl.named_constraints
):
self.drop_constraint(constraint)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
# don't do defaults for SERIAL columns
if (
metadata_column.primary_key
and metadata_column is metadata_column.table._autoincrement_column
):
return False
conn_col_default = rendered_inspector_default
defaults_equal = conn_col_default == rendered_metadata_default
if defaults_equal:
return False
if None in (
conn_col_default,
rendered_metadata_default,
metadata_column.server_default,
):
return not defaults_equal
metadata_default = metadata_column.server_default.arg
if isinstance(metadata_default, str):
if not isinstance(inspector_column.type, Numeric):
metadata_default = re.sub(r"^'|'$", "", metadata_default)
metadata_default = f"'{metadata_default}'"
metadata_default = literal_column(metadata_default)
# run a real compare against the server
conn = self.connection
assert conn is not None
return not conn.scalar(
sqla_compat._select(
literal_column(conn_col_default) == metadata_default
)
)
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
autoincrement: Optional[bool] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
**kw: Any,
) -> None:
using = kw.pop("postgresql_using", None)
if using is not None and type_ is None:
raise util.CommandError(
"postgresql_using must be used with the type_ parameter"
)
if type_ is not None:
self._exec(
PostgresqlColumnType(
table_name,
column_name,
type_,
schema=schema,
using=using,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
)
)
super().alter_column(
table_name,
column_name,
nullable=nullable,
server_default=server_default,
name=name,
schema=schema,
autoincrement=autoincrement,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_autoincrement=existing_autoincrement,
**kw,
)
def autogen_column_reflect(self, inspector, table, column_info):
if column_info.get("default") and isinstance(
column_info["type"], (INTEGER, BIGINT)
):
seq_match = re.match(
r"nextval\('(.+?)'::regclass\)", column_info["default"]
)
if seq_match:
info = sqla_compat._exec_on_inspector(
inspector,
text(
"select c.relname, a.attname "
"from pg_class as c join "
"pg_depend d on d.objid=c.oid and "
"d.classid='pg_class'::regclass and "
"d.refclassid='pg_class'::regclass "
"join pg_class t on t.oid=d.refobjid "
"join pg_attribute a on a.attrelid=t.oid and "
"a.attnum=d.refobjsubid "
"where c.relkind='S' and c.relname=:seqname"
),
seqname=seq_match.group(1),
).first()
if info:
seqname, colname = info
if colname == column_info["name"]:
log.info(
"Detected sequence named '%s' as "
"owned by integer column '%s(%s)', "
"assuming SERIAL and omitting",
seqname,
table.name,
colname,
)
# sequence, and the owner is this column,
# its a SERIAL - whack it!
del column_info["default"]
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
doubled_constraints = {
index
for index in conn_indexes
if index.info.get("duplicates_constraint")
}
for ix in doubled_constraints:
conn_indexes.remove(ix)
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)
# pg behavior regarding modifiers
# | # | compiled sql | returned sql | regexp. group is removed |
# | - | ---------------- | -----------------| ------------------------ |
# | 1 | nulls first | nulls first | - |
# | 2 | nulls last | | (?<! desc)( nulls last)$ |
# | 3 | asc | | ( asc)$ |
# | 4 | asc nulls first | nulls first | ( asc) nulls first$ |
# | 5 | asc nulls last | | ( asc nulls last)$ |
# | 6 | desc | desc | - |
# | 7 | desc nulls first | desc | desc( nulls first)$ |
# | 8 | desc nulls last | desc nulls last | - |
_default_modifiers_re = ( # order of case 2 and 5 matters
re.compile("( asc nulls last)$"), # case 5
re.compile("(?<! desc)( nulls last)$"), # case 2
re.compile("( asc)$"), # case 3
re.compile("( asc) nulls first$"), # case 4
re.compile(" desc( nulls first)$"), # case 7
)
def _cleanup_index_expr(self, index: Index, expr: str) -> str:
expr = expr.lower().replace('"', "").replace("'", "")
if index.table is not None:
# should not be needed, since include_table=False is in compile
expr = expr.replace(f"{index.table.name.lower()}.", "")
if "::" in expr:
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
# NOTE: when parsing the connection expression this cleanup could
# be skipped
for rs in self._default_modifiers_re:
if match := rs.search(expr):
start, end = match.span(1)
expr = expr[:start] + expr[end:]
break
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
# strip casts
cast_re = re.compile(r"cast\s*\(")
if cast_re.match(expr):
expr = cast_re.sub("", expr)
# remove the as type
expr = re.sub(r"as\s+[^)]+\)", "", expr)
# remove spaces
expr = expr.replace(" ", "")
return expr
def _dialect_options(
self, item: Union[Index, UniqueConstraint]
) -> Tuple[Any, ...]:
# only the positive case is returned by sqlalchemy reflection so
# None and False are threated the same
if item.dialect_kwargs.get("postgresql_nulls_not_distinct"):
return ("nulls_not_distinct",)
return ()
def compare_indexes(
self,
metadata_index: Index,
reflected_index: Index,
) -> ComparisonResult:
msg = []
unique_msg = self._compare_index_unique(
metadata_index, reflected_index
)
if unique_msg:
msg.append(unique_msg)
m_exprs = metadata_index.expressions
r_exprs = reflected_index.expressions
if len(m_exprs) != len(r_exprs):
msg.append(f"expression number {len(r_exprs)} to {len(m_exprs)}")
if msg:
# no point going further, return early
return ComparisonResult.Different(msg)
skip = []
for pos, (m_e, r_e) in enumerate(zip(m_exprs, r_exprs), 1):
m_compile = self._compile_element(m_e)
m_text = self._cleanup_index_expr(metadata_index, m_compile)
# print(f"META ORIG: {m_compile!r} CLEANUP: {m_text!r}")
r_compile = self._compile_element(r_e)
r_text = self._cleanup_index_expr(metadata_index, r_compile)
# print(f"CONN ORIG: {r_compile!r} CLEANUP: {r_text!r}")
if m_text == r_text:
continue # expressions these are equal
elif m_compile.strip().endswith("_ops") and (
" " in m_compile or ")" in m_compile # is an expression
):
skip.append(
f"expression #{pos} {m_compile!r} detected "
"as including operator clause."
)
util.warn(
f"Expression #{pos} {m_compile!r} in index "
f"{reflected_index.name!r} detected to include "
"an operator clause. Expression compare cannot proceed. "
"Please move the operator clause to the "
"``postgresql_ops`` dict to enable proper compare "
"of the index expressions: "
"https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#operator-classes", # noqa: E501
)
else:
msg.append(f"expression #{pos} {r_compile!r} to {m_compile!r}")
m_options = self._dialect_options(metadata_index)
r_options = self._dialect_options(reflected_index)
if m_options != r_options:
msg.extend(f"options {r_options} to {m_options}")
if msg:
return ComparisonResult.Different(msg)
elif skip:
# if there are other changes detected don't skip the index
return ComparisonResult.Skip(skip)
else:
return ComparisonResult.Equal()
def compare_unique_constraint(
self,
metadata_constraint: UniqueConstraint,
reflected_constraint: UniqueConstraint,
) -> ComparisonResult:
metadata_tup = self._create_metadata_constraint_sig(
metadata_constraint
)
reflected_tup = self._create_reflected_constraint_sig(
reflected_constraint
)
meta_sig = metadata_tup.unnamed
conn_sig = reflected_tup.unnamed
if conn_sig != meta_sig:
return ComparisonResult.Different(
f"expression {conn_sig} to {meta_sig}"
)
metadata_do = self._dialect_options(metadata_tup.const)
conn_do = self._dialect_options(reflected_tup.const)
if metadata_do != conn_do:
return ComparisonResult.Different(
f"expression {conn_do} to {metadata_do}"
)
return ComparisonResult.Equal()
def adjust_reflected_dialect_options(
self, reflected_options: Dict[str, Any], kind: str
) -> Dict[str, Any]:
options: Dict[str, Any]
options = reflected_options.get("dialect_options", {}).copy()
if not options.get("postgresql_include"):
options.pop("postgresql_include", None)
return options
def _compile_element(self, element: Union[ClauseElement, str]) -> str:
if isinstance(element, str):
return element
return element.compile(
dialect=self.dialect,
compile_kwargs={"literal_binds": True, "include_table": False},
).string
def render_ddl_sql_expr(
self,
expr: ClauseElement,
is_server_default: bool = False,
is_index: bool = False,
**kw: Any,
) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
"""
# apply self_group to index expressions;
# see https://github.com/sqlalchemy/sqlalchemy/blob/
# 82fa95cfce070fab401d020c6e6e4a6a96cc2578/
# lib/sqlalchemy/dialects/postgresql/base.py#L2261
if is_index and not isinstance(expr, ColumnClause):
expr = expr.self_group()
return super().render_ddl_sql_expr(
expr, is_server_default=is_server_default, is_index=is_index, **kw
)
def render_type(
self, type_: TypeEngine, autogen_context: AutogenContext
) -> Union[str, Literal[False]]:
mod = type(type_).__module__
if not mod.startswith("sqlalchemy.dialects.postgresql"):
return False
if hasattr(self, "_render_%s_type" % type_.__visit_name__):
meth = getattr(self, "_render_%s_type" % type_.__visit_name__)
return meth(type_, autogen_context)
return False
def _render_HSTORE_type(
self, type_: HSTORE, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
),
)
def _render_ARRAY_type(
self, type_: ARRAY, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "item_type", r"(.+?\()"
),
)
def _render_JSON_type(
self, type_: JSON, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
),
)
def _render_JSONB_type(
self, type_: JSONB, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
),
)
class PostgresqlColumnType(AlterColumn):
def __init__(
self, name: str, column_name: str, type_: TypeEngine, **kw
) -> None:
using = kw.pop("using", None)
super().__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
self.using = using
@compiles(RenameTable, "postgresql")
def visit_rename_table(
element: RenameTable, compiler: PGDDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
@compiles(PostgresqlColumnType, "postgresql")
def visit_column_type(
element: PostgresqlColumnType, compiler: PGDDLCompiler, **kw
) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"TYPE %s" % format_type(compiler, element.type_),
"USING %s" % element.using if element.using else "",
)
@compiles(ColumnComment, "postgresql")
def visit_column_comment(
element: ColumnComment, compiler: PGDDLCompiler, **kw
) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = (
compiler.sql_compiler.render_literal_value(
element.comment, sqltypes.String()
)
if element.comment is not None
else "NULL"
)
return ddl.format(
table_name=format_table_name(
compiler, element.table_name, element.schema
),
column_name=format_column_name(compiler, element.column_name),
comment=comment,
)
@compiles(IdentityColumnDefault, "postgresql")
def visit_identity_column(
element: IdentityColumnDefault, compiler: PGDDLCompiler, **kw
):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
)
if element.default is None:
# drop identity
text += "DROP IDENTITY"
return text
elif element.existing_server_default is None:
# add identity options
text += "ADD "
text += compiler.visit_identity_column(element.default)
return text
else:
# alter identity
diff, _, _ = element.impl._compare_identity_default(
element.default, element.existing_server_default
)
identity = element.default
for attr in sorted(diff):
if attr == "always":
text += "SET GENERATED %s " % (
"ALWAYS" if identity.always else "BY DEFAULT"
)
else:
text += "SET %s " % compiler.get_identity_options(
sqla_compat.Identity(**{attr: getattr(identity, attr)})
)
return text
@Operations.register_operation("create_exclude_constraint")
@BatchOperations.register_operation(
"create_exclude_constraint", "batch_create_exclude_constraint"
)
@ops.AddConstraintOp.register_add_constraint("exclude_constraint")
class CreateExcludeConstraintOp(ops.AddConstraintOp):
"""Represent a create exclude constraint operation."""
constraint_type = "exclude"
def __init__(
self,
constraint_name: sqla_compat._ConstraintName,
table_name: Union[str, quoted_name],
elements: Union[
Sequence[Tuple[str, str]],
Sequence[Tuple[ColumnClause[Any], str]],
],
where: Optional[Union[ColumnElement[bool], str]] = None,
schema: Optional[str] = None,
_orig_constraint: Optional[ExcludeConstraint] = None,
**kw,
) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.elements = elements
self.where = where
self.schema = schema
self._orig_constraint = _orig_constraint
self.kw = kw
@classmethod
def from_constraint( # type:ignore[override]
cls, constraint: ExcludeConstraint
) -> CreateExcludeConstraintOp:
constraint_table = sqla_compat._table_for_constraint(constraint)
return cls(
constraint.name,
constraint_table.name,
[ # type: ignore
(expr, op) for expr, name, op in constraint._render_exprs
],
where=cast("ColumnElement[bool] | None", constraint.where),
schema=constraint_table.schema,
_orig_constraint=constraint,
deferrable=constraint.deferrable,
initially=constraint.initially,
using=constraint.using,
)
def to_constraint(
self, migration_context: Optional[MigrationContext] = None
) -> ExcludeConstraint:
if self._orig_constraint is not None:
return self._orig_constraint
schema_obj = schemaobj.SchemaObjects(migration_context)
t = schema_obj.table(self.table_name, schema=self.schema)
excl = ExcludeConstraint(
*self.elements,
name=self.constraint_name,
where=self.where,
**self.kw,
)
for (
expr,
name,
oper,
) in excl._render_exprs:
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
@classmethod
def create_exclude_constraint(
cls,
operations: Operations,
constraint_name: str,
table_name: str,
*elements: Any,
**kw: Any,
) -> Optional[Table]:
"""Issue an alter to create an EXCLUDE constraint using the
current migration context.
.. note:: This method is Postgresql specific, and additionally
requires at least SQLAlchemy 1.0.
e.g.::
from alembic import op
op.create_exclude_constraint(
"user_excl",
"user",
("period", "&&"),
("group", "="),
where=("group != 'some group'"),
)
Note that the expressions work the same way as that of
the ``ExcludeConstraint`` object itself; if plain strings are
passed, quoting rules must be applied manually.
:param name: Name of the constraint.
:param table_name: String name of the source table.
:param elements: exclude conditions.
:param where: SQL expression or SQL string with optional WHERE
clause.
:param deferrable: optional bool. If set, emit DEFERRABLE or
NOT DEFERRABLE when issuing DDL for this constraint.
:param initially: optional string. If set, emit INITIALLY <value>
when issuing DDL for this constraint.
:param schema: Optional schema name to operate within.
"""
op = cls(constraint_name, table_name, elements, **kw)
return operations.invoke(op)
@classmethod
def batch_create_exclude_constraint(
cls,
operations: BatchOperations,
constraint_name: str,
*elements: Any,
**kw: Any,
) -> Optional[Table]:
"""Issue a "create exclude constraint" instruction using the
current batch migration context.
.. note:: This method is Postgresql specific, and additionally
requires at least SQLAlchemy 1.0.
.. seealso::
:meth:`.Operations.create_exclude_constraint`
"""
kw["schema"] = operations.impl.schema
op = cls(constraint_name, operations.impl.table_name, elements, **kw)
return operations.invoke(op)
@render.renderers.dispatch_for(CreateExcludeConstraintOp)
def _add_exclude_constraint(
autogen_context: AutogenContext, op: CreateExcludeConstraintOp
) -> str:
return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
@render._constraint_renderers.dispatch_for(ExcludeConstraint)
def _render_inline_exclude_constraint(
constraint: ExcludeConstraint,
autogen_context: AutogenContext,
namespace_metadata: MetaData,
) -> str:
rendered = render._user_defined_render(
"exclude", constraint, autogen_context
)
if rendered is not False:
return rendered
return _exclude_constraint(constraint, autogen_context, False)
def _postgresql_autogenerate_prefix(autogen_context: AutogenContext) -> str:
imports = autogen_context.imports
if imports is not None:
imports.add("from sqlalchemy.dialects import postgresql")
return "postgresql."
def _exclude_constraint(
constraint: ExcludeConstraint,
autogen_context: AutogenContext,
alter: bool,
) -> str:
opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = []
has_batch = autogen_context._has_batch
if constraint.deferrable:
opts.append(("deferrable", str(constraint.deferrable)))
if constraint.initially:
opts.append(("initially", str(constraint.initially)))
if constraint.using:
opts.append(("using", str(constraint.using)))
if not has_batch and alter and constraint.table.schema:
opts.append(("schema", render._ident(constraint.table.schema)))
if not alter and constraint.name:
opts.append(
("name", render._render_gen_name(autogen_context, constraint.name))
)
def do_expr_where_opts():
args = [
"(%s, %r)"
% (
_render_potential_column(
sqltext, # type:ignore[arg-type]
autogen_context,
),
opstring,
)
for sqltext, name, opstring in constraint._render_exprs
]
if constraint.where is not None:
args.append(
"where=%s"
% render._render_potential_expr(
constraint.where, autogen_context
)
)
args.extend(["%s=%r" % (k, v) for k, v in opts])
return args
if alter:
args = [
repr(render._render_gen_name(autogen_context, constraint.name))
]
if not has_batch:
args += [repr(render._ident(constraint.table.name))]
args.extend(do_expr_where_opts())
return "%(prefix)screate_exclude_constraint(%(args)s)" % {
"prefix": render._alembic_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}
else:
args = do_expr_where_opts()
return "%(prefix)sExcludeConstraint(%(args)s)" % {
"prefix": _postgresql_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}
def _render_potential_column(
value: Union[
ColumnClause[Any], Column[Any], TextClause, FunctionElement[Any]
],
autogen_context: AutogenContext,
) -> str:
if isinstance(value, ColumnClause):
if value.is_literal:
# like literal_column("int8range(from, to)") in ExcludeConstraint
template = "%(prefix)sliteral_column(%(name)r)"
else:
template = "%(prefix)scolumn(%(name)r)"
return template % {
"prefix": render._sqlalchemy_autogenerate_prefix(autogen_context),
"name": value.name,
}
else:
return render._render_potential_expr(
value,
autogen_context,
wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
)

View File

@@ -0,0 +1,225 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import re
from typing import Any
from typing import Dict
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import JSON
from sqlalchemy import schema
from sqlalchemy import sql
from .base import alter_table
from .base import format_table_name
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.elements import Cast
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.type_api import TypeEngine
from ..operations.batch import BatchOperationsImpl
class SQLiteImpl(DefaultImpl):
__dialect__ = "sqlite"
transactional_ddl = False
"""SQLite supports transactional DDL, but pysqlite does not:
see: http://bugs.python.org/issue10740
"""
def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
Normally, only returns True on SQLite when operations other
than add_column are present.
"""
for op in batch_op.batch:
if op[0] == "add_column":
col = op[1][1]
if isinstance(
col.server_default, schema.DefaultClause
) and isinstance(col.server_default.arg, sql.ClauseElement):
return True
elif (
isinstance(col.server_default, util.sqla_compat.Computed)
and col.server_default.persisted
):
return True
elif op[0] not in ("create_index", "drop_index"):
return True
else:
return False
def add_constraint(self, const: Constraint):
# attempt to distinguish between an
# auto-gen constraint and an explicit one
if const._create_rule is None:
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
elif const._create_rule(self):
util.warn(
"Skipping unsupported ALTER for "
"creation of implicit constraint. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
def drop_constraint(self, const: Constraint):
if const._create_rule is None:
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
def compare_server_default(
self,
inspector_column: Column[Any],
metadata_column: Column[Any],
rendered_metadata_default: Optional[str],
rendered_inspector_default: Optional[str],
) -> bool:
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_metadata_default
)
rendered_metadata_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default
)
if rendered_inspector_default is not None:
rendered_inspector_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_inspector_default
)
rendered_inspector_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default
)
return rendered_inspector_default != rendered_metadata_default
def _guess_if_default_is_unparenthesized_sql_expr(
self, expr: Optional[str]
) -> bool:
"""Determine if a server default is a SQL expression or a constant.
There are too many assertions that expect server defaults to round-trip
identically without parenthesis added so we will add parens only in
very specific cases.
"""
if not expr:
return False
elif re.match(r"^[0-9\.]$", expr):
return False
elif re.match(r"^'.+'$", expr):
return False
elif re.match(r"^\(.+\)$", expr):
return False
else:
return True
def autogen_column_reflect(
self,
inspector: Inspector,
table: Table,
column_info: Dict[str, Any],
) -> None:
# SQLite expression defaults require parenthesis when sent
# as DDL
if self._guess_if_default_is_unparenthesized_sql_expr(
column_info.get("default", None)
):
column_info["default"] = "(%s)" % (column_info["default"],)
def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw
) -> str:
# SQLite expression defaults require parenthesis when sent
# as DDL
str_expr = super().render_ddl_sql_expr(
expr, is_server_default=is_server_default, **kw
)
if (
is_server_default
and self._guess_if_default_is_unparenthesized_sql_expr(str_expr)
):
str_expr = "(%s)" % (str_expr,)
return str_expr
def cast_for_batch_migrate(
self,
existing: Column[Any],
existing_transfer: Dict[str, Union[TypeEngine, Cast]],
new_type: TypeEngine,
) -> None:
if (
existing.type._type_affinity is not new_type._type_affinity
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(
existing_transfer["expr"], new_type
)
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
self._skip_functional_indexes(metadata_indexes, conn_indexes)
@compiles(RenameTable, "sqlite")
def visit_rename_table(
element: RenameTable, compiler: DDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
# @compiles(AddColumn, 'sqlite')
# def visit_add_column(element, compiler, **kw):
# return "%s %s" % (
# alter_table(compiler, element.table_name, element.schema),
# add_column(compiler, element.column, **kw)
# )
# def add_column(compiler, column, **kw):
# text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
# need to modify SQLAlchemy so that the CHECK associated with a Boolean
# or Enum gets placed as part of the column constraints, not the Table
# see ticket 98
# for const in column.constraints:
# text += compiler.process(AddConstraint(const))
# return text