# Source code for sqlspec.builder._ddl

"""DDL statement builders.

Provides builders for DDL operations including CREATE, DROP, ALTER,
TRUNCATE, and other schema manipulation statements.
"""

from typing import TYPE_CHECKING, Any, cast

from mypy_extensions import trait
from sqlglot import exp
from typing_extensions import Self

from sqlspec.builder._base import BuiltQuery, QueryBuilder
from sqlspec.builder._select import Select
from sqlspec.core import SQL, SQLResult
from sqlspec.utils.type_guards import has_sqlglot_expression, has_with_method

if TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    from sqlspec.builder._column import ColumnExpression
    from sqlspec.core import StatementConfig

# Public names exported by this module (``from sqlspec.builder._ddl import *``).
__all__ = (
    "AlterOperation",
    "AlterTable",
    "ColumnDefinition",
    "CommentOn",
    "ConstraintDefinition",
    "CreateIndex",
    "CreateMaterializedView",
    "CreateSchema",
    "CreateTable",
    "CreateTableAsSelect",
    "CreateView",
    "DDLBuilder",
    "DropIndex",
    "DropMaterializedView",
    "DropSchema",
    "DropTable",
    "DropView",
    "RenameTable",
    "Truncate",
)

# Table-level constraint kinds; ``build_constraint_expression`` dispatches on
# ``ConstraintDefinition.constraint_type`` compared against these strings.
CONSTRAINT_TYPE_PRIMARY_KEY = "PRIMARY KEY"
CONSTRAINT_TYPE_FOREIGN_KEY = "FOREIGN KEY"
CONSTRAINT_TYPE_UNIQUE = "UNIQUE"
CONSTRAINT_TYPE_CHECK = "CHECK"

# Foreign-key referential action keywords (ON DELETE / ON UPDATE actions).
FOREIGN_KEY_ACTION_CASCADE = "CASCADE"
FOREIGN_KEY_ACTION_SET_NULL = "SET NULL"
FOREIGN_KEY_ACTION_SET_DEFAULT = "SET DEFAULT"
FOREIGN_KEY_ACTION_RESTRICT = "RESTRICT"
FOREIGN_KEY_ACTION_NO_ACTION = "NO ACTION"

# Set of all referential action keywords above.
# NOTE(review): not referenced in this view; presumably used to validate
# ConstraintDefinition.on_delete / on_update elsewhere in the module — confirm.
VALID_FOREIGN_KEY_ACTIONS = {
    FOREIGN_KEY_ACTION_CASCADE,
    FOREIGN_KEY_ACTION_SET_NULL,
    FOREIGN_KEY_ACTION_SET_DEFAULT,
    FOREIGN_KEY_ACTION_RESTRICT,
    FOREIGN_KEY_ACTION_NO_ACTION,
}

# Set of all constraint kinds above.
# NOTE(review): not referenced in this view; presumably used for validation
# elsewhere in the module — confirm.
VALID_CONSTRAINT_TYPES = {
    CONSTRAINT_TYPE_PRIMARY_KEY,
    CONSTRAINT_TYPE_FOREIGN_KEY,
    CONSTRAINT_TYPE_UNIQUE,
    CONSTRAINT_TYPE_CHECK,
}

# String column defaults that ``build_column_expression`` matches
# case-insensitively and emits as SQL functions instead of string literals.
CURRENT_TIMESTAMP_KEYWORD = "CURRENT_TIMESTAMP"
CURRENT_DATE_KEYWORD = "CURRENT_DATE"
CURRENT_TIME_KEYWORD = "CURRENT_TIME"


def build_column_expression(col: "ColumnDefinition") -> "exp.Expr":
    """Build a SQLGlot ``ColumnDef`` expression for a column definition.

    Column constraints are attached in a fixed order: NOT NULL, PRIMARY KEY,
    UNIQUE, auto-increment, DEFAULT, CHECK, COMMENT, generated expression,
    COLLATE.

    Args:
        col: The column definition to convert. Its ``name`` and ``dtype`` build
            the column itself; its constraint attributes select which column
            constraints are attached.

    Returns:
        An ``exp.ColumnDef`` carrying any applicable column constraints.
    """
    col_def = exp.ColumnDef(this=exp.to_identifier(col.name), kind=exp.DataType.build(col.dtype))

    constraints: list[exp.ColumnConstraint] = []

    if col.not_null:
        constraints.append(exp.ColumnConstraint(kind=exp.NotNullColumnConstraint()))

    if col.primary_key:
        constraints.append(exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint()))

    if col.unique:
        constraints.append(exp.ColumnConstraint(kind=exp.UniqueColumnConstraint()))

    if col.auto_increment:
        constraints.append(exp.ColumnConstraint(kind=exp.AutoIncrementColumnConstraint()))

    if col.default is not None:
        default_expr: exp.Expr | None
        if isinstance(col.default, str):
            default_upper = col.default.upper()
            # Well-known temporal keywords become SQL functions, not quoted strings.
            if default_upper == CURRENT_TIMESTAMP_KEYWORD:
                default_expr = exp.CurrentTimestamp()
            elif default_upper == CURRENT_DATE_KEYWORD:
                default_expr = exp.CurrentDate()
            elif default_upper == CURRENT_TIME_KEYWORD:
                default_expr = exp.CurrentTime()
            elif "(" in col.default:
                # Heuristic: a string containing "(" is treated as a SQL
                # expression (e.g. a function call) and parsed, not quoted.
                default_expr = exp.maybe_parse(col.default)
            else:
                default_expr = exp.convert(col.default)
        else:
            default_expr = exp.convert(col.default)

        constraints.append(exp.ColumnConstraint(kind=exp.DefaultColumnConstraint(this=default_expr)))

    if col.check:
        constraints.append(exp.ColumnConstraint(kind=exp.Check(this=exp.maybe_parse(col.check))))

    if col.comment:
        constraints.append(exp.ColumnConstraint(kind=exp.CommentColumnConstraint(this=exp.convert(col.comment))))

    if col.generated:
        # In sqlglot, ``GeneratedAsIdentityColumnConstraint.this`` is the
        # ALWAYS (True) / BY DEFAULT (False) flag; the generation expression
        # belongs in ``expression``. Passing the parsed expression as ``this``
        # rendered ``GENERATED ALWAYS AS IDENTITY`` and dropped the expression.
        # NOTE(review): PostgreSQL (before 18) additionally requires ``STORED``
        # after the expression — confirm dialect handling.
        constraints.append(
            exp.ColumnConstraint(
                kind=exp.GeneratedAsIdentityColumnConstraint(this=True, expression=exp.maybe_parse(col.generated))
            )
        )

    if col.collate:
        constraints.append(exp.ColumnConstraint(kind=exp.CollateColumnConstraint(this=exp.to_identifier(col.collate))))

    if constraints:
        col_def.set("constraints", constraints)

    return col_def


def build_constraint_expression(constraint: "ConstraintDefinition") -> "exp.Expr | None":
    """Build a SQLGlot expression for a table-level constraint.

    Handles PRIMARY KEY, FOREIGN KEY, UNIQUE and CHECK constraint types. When
    ``constraint.name`` is set, the constraint body is wrapped in an
    ``exp.Constraint`` so it renders as ``CONSTRAINT <name> <body>``.

    Args:
        constraint: The constraint definition to convert.

    Returns:
        The constraint expression, or ``None`` if ``constraint.constraint_type``
        is not one of the four handled types.
    """
    constraint_type = constraint.constraint_type
    body: "exp.Expr | None" = None

    if constraint_type == CONSTRAINT_TYPE_PRIMARY_KEY:
        body = exp.PrimaryKey(expressions=[exp.to_identifier(col) for col in constraint.columns])

    elif constraint_type == CONSTRAINT_TYPE_FOREIGN_KEY:
        # Referential actions and deferrability are rendered as options of the
        # REFERENCES clause, in the order SQL expects them:
        #   REFERENCES t (c) ON DELETE ... ON UPDATE ... DEFERRABLE INITIALLY ...
        # sqlglot's ``Reference`` node has no ``on_delete``/``on_update`` args,
        # so passing them as keywords left the actions out of the generated SQL.
        reference_options: list[str] = []
        if constraint.on_delete:
            reference_options.append(f"ON DELETE {constraint.on_delete}")
        if constraint.on_update:
            reference_options.append(f"ON UPDATE {constraint.on_update}")
        if constraint.deferrable:
            reference_options.append(
                "DEFERRABLE INITIALLY DEFERRED" if constraint.initially_deferred else "DEFERRABLE INITIALLY IMMEDIATE"
            )

        body = exp.ForeignKey(
            expressions=[exp.to_identifier(col) for col in constraint.columns],
            reference=exp.Reference(
                this=exp.to_table(constraint.references_table) if constraint.references_table else None,
                expressions=[exp.to_identifier(col) for col in constraint.references_columns],
                options=reference_options or None,
            ),
        )

    elif constraint_type == CONSTRAINT_TYPE_UNIQUE:
        # NOTE(review): ``UniqueKeyProperty`` appears to render as
        # ``UNIQUE KEY (...)``, which some dialects (e.g. PostgreSQL, SQLite)
        # may reject; ``UniqueColumnConstraint(this=exp.Schema(...))`` renders
        # ``UNIQUE (...)`` — confirm before switching.
        body = exp.UniqueKeyProperty(expressions=[exp.to_identifier(col) for col in constraint.columns])

    elif constraint_type == CONSTRAINT_TYPE_CHECK:
        # A pre-built expression wins over the raw condition string.
        if constraint.condition_expr is not None:
            condition = constraint.condition_expr
        elif constraint.condition:
            condition = exp.maybe_parse(constraint.condition)
        else:
            condition = None
        body = exp.Check(this=condition)

    if body is None:
        return None

    if constraint.name:
        # ``exp.Constraint`` renders its ``expressions`` list; the previous
        # singular ``expression=`` keyword was not rendered, so named
        # constraints came out as ``CONSTRAINT <name>`` with no body.
        return exp.Constraint(this=exp.to_identifier(constraint.name), expressions=[body])
    return body


class DDLBuilder(QueryBuilder):
    """Shared base for builders that emit DDL statements (CREATE, DROP, ALTER, ...).

    Subclasses provide the statement's SQLGlot expression through
    ``_create_base_expression``. ``build`` calls it lazily the first time the
    statement is built.
    """

    __slots__ = ()

    def __init__(self, dialect: "DialectType" = None) -> None:
        super().__init__(dialect=dialect)
        # Filled in lazily by ``build`` via ``_create_base_expression``.
        self._expression: exp.Expr | None = None

    def _create_base_expression(self) -> exp.Expr:
        """Produce the statement expression; subclasses must override this.

        Raises:
            NotImplementedError: Always, when not overridden.
        """
        error_text = "Subclasses must implement _create_base_expression."
        raise NotImplementedError(error_text)

    def _resolve_select_query(self, query: object, context: str, *, require_select_type: bool = True) -> exp.Expr:
        """Turn a SELECT source into a SQLGlot expression.

        Accepted inputs are an ``SQL`` statement, a ``Select`` builder, a raw SQL
        string, or an existing SQLGlot expression. Any parameters carried by an
        ``SQL`` or ``Select`` input are copied into this builder's
        ``_parameters``.

        Args:
            query: The SELECT source to resolve.
            context: Caller label included in the unsupported-type error message.
            require_select_type: When True, the resolved expression must be an
                ``exp.Select``.

        Returns:
            The resolved expression.

        Errors are reported through ``_raise_sql_builder_error`` for an
        unsupported input type, a ``None`` expression, or (when
        ``require_select_type`` is True) a non-SELECT expression.
        """
        carried_params: dict[str, Any] | None = None

        if isinstance(query, (SQL, Select)):
            # Both statement objects carry their own bound parameters.
            resolved = query.expression if isinstance(query, SQL) else query.get_expression()
            carried_params = query.parameters
        elif isinstance(query, str):
            resolved = exp.maybe_parse(query)
        elif isinstance(query, exp.Expr):
            resolved = query
        else:
            self._raise_sql_builder_error(f"Unsupported type for SELECT query in {context}.")

        acceptable = resolved is not None and (not require_select_type or isinstance(resolved, exp.Select))
        if not acceptable:
            self._raise_sql_builder_error("SELECT query must be a valid SELECT expression.")

        if carried_params:
            for param_key, param_val in carried_params.items():
                self._parameters[param_key] = param_val

        return resolved

    @property
    def _expected_result_type(self) -> "type[SQLResult]":
        """Result class produced when a DDL statement is executed."""
        return SQLResult

    def build(self, dialect: "DialectType" = None) -> "BuiltQuery":
        """Build the statement, creating the base expression on first use.

        Args:
            dialect: Optional dialect override forwarded to the base builder.

        Returns:
            The built query from ``QueryBuilder.build``.
        """
        if self._expression is None:
            self._expression = self._create_base_expression()
        return super().build(dialect=dialect)