315 lines
9.3 KiB
Python
315 lines
9.3 KiB
Python
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
|
|
# mypy: no-warn-return-any, allow-any-generics
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
from typing import cast
|
|
from typing import Collection
|
|
from typing import TYPE_CHECKING
|
|
|
|
from sqlalchemy.sql.elements import conv
|
|
from typing_extensions import Self
|
|
|
|
from ...util import sqla_compat
|
|
|
|
if TYPE_CHECKING:
|
|
from sqlalchemy import Table
|
|
from sqlalchemy.engine import Inspector
|
|
from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
|
|
from sqlalchemy.engine.interfaces import ReflectedIndex
|
|
from sqlalchemy.engine.interfaces import ReflectedUniqueConstraint
|
|
from sqlalchemy.engine.reflection import _ReflectionInfo
|
|
|
|
# Reflection categories handled for each table.  Every entry ``key`` is used
# to look up ``Inspector.get_<key>`` and ``Inspector.get_multi_<key>`` and to
# build the ``alembic_<key>`` key stored in ``Inspector.info_cache``.
_INSP_KEYS = (
    "columns",
    "pk_constraint",
    "foreign_keys",
    "indexes",
    "unique_constraints",
    "table_comment",
    "check_constraints",
    "table_options",
)

# Subset of ``_INSP_KEYS`` whose reflected entries have their ``"name"``
# wrapped in ``conv`` when the reflection info for a table is assembled.
_CONSTRAINT_INSP_KEYS = (
    "pk_constraint",
    "foreign_keys",
    "indexes",
    "unique_constraints",
    "check_constraints",
)
|
|
|
|
|
|
class _InspectorConv:
    """Version-dispatching wrapper around a SQLAlchemy ``Inspector``.

    Constructing this class (or any subclass) never returns a plain
    ``_InspectorConv``: ``__new__`` builds a ``_SQLA2InspectorConv`` when
    ``sqla_compat.sqla_2`` is truthy and a ``_LegacyInspectorConv``
    otherwise.  The reflection methods defined here are placeholders that
    raise ``NotImplementedError``.
    """

    __slots__ = ("inspector",)

    def __new__(cls, inspector: Inspector) -> Self:
        """Create the implementation matching the SQLAlchemy version.

        ``__init__`` is invoked explicitly on the new object because the
        returned instance is not necessarily an instance of ``cls`` (in
        which case Python would not call ``__init__`` automatically).
        """
        obj: Any
        if sqla_compat.sqla_2:
            obj = object.__new__(_SQLA2InspectorConv)
            _SQLA2InspectorConv.__init__(obj, inspector)
        else:
            obj = object.__new__(_LegacyInspectorConv)
            _LegacyInspectorConv.__init__(obj, inspector)
        return cast(Self, obj)

    def __init__(self, inspector: Inspector) -> None:
        """Store the wrapped ``inspector``."""
        self.inspector = inspector

    def pre_cache_tables(
        self,
        schema: str | None,
        tablenames: list[str],
        all_available_tablenames: Collection[str],
    ) -> None:
        """Hook for bulk-loading reflection data; a no-op by default."""
        pass

    def get_unique_constraints(
        self, tname: str, schema: str | None
    ) -> list[ReflectedUniqueConstraint]:
        """Return reflected unique constraints; must be overridden."""
        raise NotImplementedError()

    def get_indexes(
        self, tname: str, schema: str | None
    ) -> list[ReflectedIndex]:
        """Return reflected indexes; must be overridden."""
        raise NotImplementedError()

    def get_foreign_keys(
        self, tname: str, schema: str | None
    ) -> list[ReflectedForeignKeyConstraint]:
        """Return reflected foreign keys; must be overridden."""
        raise NotImplementedError()

    def reflect_table(self, table: Table) -> None:
        """Reflect ``table`` in place; must be overridden."""
        raise NotImplementedError()
|
|
|
|
|
|
class _LegacyInspectorConv(_InspectorConv):
    """Inspector wrapper used when ``sqla_compat.sqla_2`` is falsy.

    Delegates directly to the wrapped ``Inspector`` and wraps every
    reflected constraint / index name in ``conv``.
    """

    def _apply_reflectinfo_conv(self, consts):
        """Wrap the ``"name"`` of each reflected dict in ``conv``, in place.

        ``consts`` is an iterable of mapping-style reflection records.
        Falsy input is returned untouched; names that are ``None`` or
        already ``conv`` instances are left alone.  Returns ``consts``.
        """
        if not consts:
            return consts
        for const in consts:
            if const["name"] is not None and not isinstance(
                const["name"], conv
            ):
                const["name"] = conv(const["name"])
        return consts

    def _apply_constraint_conv(self, consts):
        """Wrap the ``.name`` of each object in ``conv``, in place.

        Attribute-based counterpart of ``_apply_reflectinfo_conv``; it is
        applied to a table's ``constraints`` and ``indexes`` collections
        after reflection.  Returns ``consts``.
        """
        if not consts:
            return consts
        for const in consts:
            if const.name is not None and not isinstance(const.name, conv):
                const.name = conv(const.name)
        return consts

    def get_indexes(
        self, tname: str, schema: str | None
    ) -> list[ReflectedIndex]:
        """Return the inspector's indexes for ``tname``, names as ``conv``."""
        return self._apply_reflectinfo_conv(
            self.inspector.get_indexes(tname, schema=schema)
        )

    def get_unique_constraints(
        self, tname: str, schema: str | None
    ) -> list[ReflectedUniqueConstraint]:
        """Return the inspector's unique constraints, names as ``conv``."""
        return self._apply_reflectinfo_conv(
            self.inspector.get_unique_constraints(tname, schema=schema)
        )

    def get_foreign_keys(
        self, tname: str, schema: str | None
    ) -> list[ReflectedForeignKeyConstraint]:
        """Return the inspector's foreign keys, names as ``conv``."""
        return self._apply_reflectinfo_conv(
            self.inspector.get_foreign_keys(tname, schema=schema)
        )

    def reflect_table(self, table: Table) -> None:
        """Reflect ``table`` in place, then mark its names as ``conv``."""
        self.inspector.reflect_table(table, include_columns=None)

        # Names are converted only after reflection has populated the
        # table's constraint and index collections.
        self._apply_constraint_conv(table.constraints)
        self._apply_constraint_conv(table.indexes)
|
|
|
|
|
|
class _SQLA2InspectorConv(_InspectorConv):
    """Inspector wrapper used when ``sqla_compat.sqla_2`` is truthy.

    Reflection data is bulk-loaded through the inspector's ``get_multi_*``
    methods and kept in ``Inspector.info_cache`` under ``alembic_<key>``
    entries.  Per-table lookups are answered from that cache when possible
    and fall back to the single-table inspector method otherwise.
    """

    def _pre_cache(
        self,
        schema: str | None,
        tablenames: list[str],
        all_available_tablenames: Collection[str],
        info_key: str,
        inspector_method: Any,
    ) -> None:
        """Fill ``info_cache[info_key]`` from a bulk reflection call.

        Does nothing when ``info_key`` is already cached.  When
        ``inspector_method`` raises ``NotImplementedError`` the exception
        *class* itself is stored as a marker, which ``_return_from_cache``
        checks for.

        :param schema: schema passed through to ``inspector_method``.
        :param tablenames: tables of interest, used as ``filter_names``
         unless the heuristic below decides to load everything.
        :param all_available_tablenames: all tables known for the schema;
         only its size relative to ``tablenames`` is used.
        :param info_key: key under which the result is cached.
        :param inspector_method: bulk reflection callable accepting
         ``schema`` and ``filter_names`` keyword arguments.
        """

        if info_key in self.inspector.info_cache:
            return

        # heuristic vendored from SQLAlchemy 2.0
        # if more than 50% of the tables in the db are in filter_names load all
        # the tables, since it's most likely faster to avoid a filter on that
        # many tables. also if a dialect doesn't have a "multi" method then
        # return the filter names
        # (the ratio is only computed when more than 100 tables are requested)
        if tablenames and all_available_tablenames and len(tablenames) > 100:
            fraction = len(tablenames) / len(all_available_tablenames)
        else:
            fraction = None

        if (
            fraction is None
            or fraction <= 0.5
            or not self.inspector.dialect._overrides_default(
                inspector_method.__name__
            )
        ):
            optimized_filter_names = tablenames
        else:
            # None means "no filter": reflect every table in one pass
            optimized_filter_names = None

        try:
            elements = inspector_method(
                schema=schema, filter_names=optimized_filter_names
            )
        except NotImplementedError:
            self.inspector.info_cache[info_key] = NotImplementedError
        else:
            self.inspector.info_cache[info_key] = elements

    def _return_from_cache(
        self,
        tname: str,
        schema: str | None,
        info_key: str,
        inspector_method: Any,
        apply_constraint_conv: bool = False,
        optional: bool = True,
    ) -> Any:
        """Return reflection data for one table, preferring the cache.

        :param tname: table name.
        :param schema: schema name or ``None``.
        :param info_key: ``info_cache`` key filled by ``_pre_cache``.
        :param inspector_method: single-table fallback, called as
         ``inspector_method(tname, schema=schema)`` when the cache has no
         entry for ``(schema, tname)``.
        :param apply_constraint_conv: when true, names in the result are
         wrapped in ``conv`` before it is returned.
        :param optional: controls what happens when the reflection method
         is not implemented: return ``{}`` if true, otherwise raise.
        :raises NotImplementedError: if ``optional`` is false and the
         reflection method is not implemented.
        """
        # unique sentinel so that a cached value of None counts as a hit
        not_in_cache = object()

        if info_key in self.inspector.info_cache:
            cache = self.inspector.info_cache[info_key]
            if cache is NotImplementedError:
                if optional:
                    return {}
                else:
                    # maintain NotImplementedError as alembic compare
                    # uses these to determine classes of construct that it
                    # should not compare to DB elements
                    raise NotImplementedError()

            individual = cache.get((schema, tname), not_in_cache)

            if individual is not not_in_cache:
                if apply_constraint_conv and individual is not None:
                    return self._apply_reflectinfo_conv(individual)
                else:
                    return individual

        # cache miss (or nothing pre-cached): reflect this table directly
        try:
            data = inspector_method(tname, schema=schema)
        except NotImplementedError:
            if optional:
                return {}
            else:
                raise

        if apply_constraint_conv:
            return self._apply_reflectinfo_conv(data)
        else:
            return data

    def get_unique_constraints(
        self, tname: str, schema: str | None
    ) -> list[ReflectedUniqueConstraint]:
        """Return unique constraints for ``tname`` with ``conv`` names.

        ``NotImplementedError`` is propagated (``optional=False``).
        """
        return self._return_from_cache(
            tname,
            schema,
            "alembic_unique_constraints",
            self.inspector.get_unique_constraints,
            apply_constraint_conv=True,
            optional=False,
        )

    def get_indexes(
        self, tname: str, schema: str | None
    ) -> list[ReflectedIndex]:
        """Return indexes for ``tname`` with ``conv`` names.

        ``NotImplementedError`` is propagated (``optional=False``).
        """
        return self._return_from_cache(
            tname,
            schema,
            "alembic_indexes",
            self.inspector.get_indexes,
            apply_constraint_conv=True,
            optional=False,
        )

    def get_foreign_keys(
        self, tname: str, schema: str | None
    ) -> list[ReflectedForeignKeyConstraint]:
        """Return foreign keys for ``tname`` with ``conv`` names.

        Uses the default ``optional=True``, so an unimplemented reflection
        method yields an empty result instead of raising.
        """
        return self._return_from_cache(
            tname,
            schema,
            "alembic_foreign_keys",
            self.inspector.get_foreign_keys,
            apply_constraint_conv=True,
        )

    def _apply_reflectinfo_conv(self, consts):
        """Wrap reflected ``"name"`` values in ``conv``, in place.

        ``consts`` may be a single reflection dict or an iterable of them;
        a lone dict is treated as a one-element sequence.  Falsy input is
        returned untouched; names that are ``None`` or already ``conv``
        instances are left alone.  Returns ``consts``.
        """
        if not consts:
            return consts
        for const in consts if not isinstance(consts, dict) else [consts]:
            if const["name"] is not None and not isinstance(
                const["name"], conv
            ):
                const["name"] = conv(const["name"])
        return consts

    def pre_cache_tables(
        self,
        schema: str | None,
        tablenames: list[str],
        all_available_tablenames: Collection[str],
    ) -> None:
        """Bulk-load every ``_INSP_KEYS`` category into the info cache."""
        for key in _INSP_KEYS:
            keyname = f"alembic_{key}"
            meth = getattr(self.inspector, f"get_multi_{key}")

            self._pre_cache(
                schema,
                tablenames,
                all_available_tablenames,
                keyname,
                meth,
            )

    def _make_reflection_info(
        self, tname: str, schema: str | None
    ) -> _ReflectionInfo:
        """Build a ``_ReflectionInfo`` holding data for a single table.

        Each ``_INSP_KEYS`` category becomes a one-entry mapping keyed by
        ``(schema, tname)``; categories listed in ``_CONSTRAINT_INSP_KEYS``
        have their names wrapped in ``conv``.
        """
        # imported at call time; the module-level import of this name is
        # guarded by TYPE_CHECKING
        from sqlalchemy.engine.reflection import _ReflectionInfo

        table_key = (schema, tname)

        return _ReflectionInfo(
            unreflectable={},
            **{
                key: {
                    table_key: self._return_from_cache(
                        tname,
                        schema,
                        f"alembic_{key}",
                        getattr(self.inspector, f"get_{key}"),
                        apply_constraint_conv=(key in _CONSTRAINT_INSP_KEYS),
                    )
                }
                for key in _INSP_KEYS
            },
        )

    def reflect_table(self, table: Table) -> None:
        """Reflect ``table`` in place using the prepared reflection info."""
        ri = self._make_reflection_info(table.name, table.schema)

        self.inspector.reflect_table(
            table,
            include_columns=None,
            resolve_fks=False,
            _reflect_info=ri,
        )
|