# Source code for sqlspec.builder._factory

"""SQL factory for creating SQL builders and column expressions.

Provides statement builders (select, insert, update, etc.) and column expressions.
"""

import hashlib
import logging
from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, final

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import ParseError as SQLGlotParseError

from sqlspec.builder._base import QueryBuilder
from sqlspec.builder._column import Column
from sqlspec.builder._ddl import (
    AlterTable,
    CommentOn,
    CreateIndex,
    CreateMaterializedView,
    CreateSchema,
    CreateTable,
    CreateTableAsSelect,
    CreateView,
    DropIndex,
    DropMaterializedView,
    DropSchema,
    DropTable,
    DropView,
    RenameTable,
    Truncate,
)
from sqlspec.builder._delete import Delete
from sqlspec.builder._explain import Explain
from sqlspec.builder._expression_wrappers import (
    AggregateExpression,
    ConversionExpression,
    FunctionExpression,
    MathExpression,
    StringExpression,
)
from sqlspec.builder._insert import Insert
from sqlspec.builder._join import JoinBuilder, create_join_builder
from sqlspec.builder._merge import Merge
from sqlspec.builder._parsing_utils import extract_expression, to_expression
from sqlspec.builder._select import Case, Select, SubqueryBuilder, WindowFunctionBuilder
from sqlspec.builder._update import Update
from sqlspec.core import SQL
from sqlspec.core.explain import ExplainFormat, ExplainOptions
from sqlspec.exceptions import SQLBuilderError
from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    from sqlspec.builder._expression_wrappers import ExpressionWrapper
    from sqlspec.protocols import SQLBuilderProtocol


# Public API of this module. Most entries re-export builder classes imported
# above (statement builders, DDL builders, Case, Column, WindowFunctionBuilder);
# the build_copy_* helpers are defined in this module.
# NOTE(review): "SQLFactory", "build_copy_to_statement" and "sql" are presumably
# defined further down in this file - confirm they exist before relying on them.
__all__ = (
    "AlterTable",
    "Case",
    "Column",
    "CommentOn",
    "CreateIndex",
    "CreateMaterializedView",
    "CreateSchema",
    "CreateTable",
    "CreateTableAsSelect",
    "CreateView",
    "Delete",
    "DropIndex",
    "DropMaterializedView",
    "DropSchema",
    "DropTable",
    "DropView",
    "Explain",
    "Insert",
    "Merge",
    "RenameTable",
    "SQLFactory",
    "Select",
    "Truncate",
    "Update",
    "WindowFunctionBuilder",
    "build_copy_from_statement",
    "build_copy_statement",
    "build_copy_to_statement",
    "sql",
)

# Module-level logger obtained from the project's logging helper.
logger = get_logger("sqlspec.builder.factory")

# Type variable bounded by QueryBuilder (i.e. any QueryBuilder subclass).
BuilderT = TypeVar("BuilderT", bound=QueryBuilder)

# NOTE(review): the code that consumes the two thresholds below is not in view.
# Judging by the names, the first is a minimum length for a string to be
# considered SQL-like and the second a minimum argument count for a
# DECODE-style helper - confirm against their call sites.
MIN_SQL_LIKE_STRING_LENGTH = 6
MIN_DECODE_ARGS = 2
# Upper-case SQL keywords; presumably matched against the first word of a
# string to decide whether it starts a SQL statement - verify against usage.
SQL_STARTERS = {
    "SELECT",
    "INSERT",
    "UPDATE",
    "DELETE",
    "MERGE",
    "WITH",
    "CALL",
    "DECLARE",
    "BEGIN",
    "END",
    "CREATE",
    "DROP",
    "ALTER",
    "TRUNCATE",
    "RENAME",
    "GRANT",
    "REVOKE",
    "SET",
    "SHOW",
    "USE",
    "EXPLAIN",
    "OPTIMIZE",
    "VACUUM",
    "COPY",
}


def _fingerprint_sql(sql: str) -> str:
    """Return a short hexadecimal fingerprint of a SQL string.

    Args:
        sql: The SQL text to fingerprint.

    Returns:
        The first 12 hex characters of the SHA-256 digest of the UTF-8
        encoded text.
    """
    # errors="replace" keeps unencodable characters (e.g. lone surrogates)
    # from raising; they are substituted during encoding instead.
    encoded = sql.encode("utf-8", errors="replace")
    return hashlib.sha256(encoded).hexdigest()[:12]


def _normalize_copy_dialect(dialect: DialectType | None) -> str:
    """Resolve the dialect used for rendering COPY statements to a string.

    Args:
        dialect: Dialect selector, or None to fall back to the default.

    Returns:
        "postgres" when ``dialect`` is None, the value unchanged when it is
        already a string, otherwise ``str(dialect)``.
    """
    if dialect is None:
        return "postgres"
    # NOTE(review): non-string dialect objects are stringified as-is; confirm
    # that their str() form is a dialect name the renderer accepts.
    return dialect if isinstance(dialect, str) else str(dialect)


def _to_copy_schema(table: str, columns: "Sequence[str] | None") -> exp.Expr:
    """Build the target node of a COPY statement.

    Args:
        table: Name of the table, passed to ``exp.table_``.
        columns: Optional column names; None or an empty sequence means no
            column list.

    Returns:
        The bare table node when no columns are given, otherwise an
        ``exp.Schema`` wrapping the table together with one column node per
        name.
    """
    # NOTE(review): ``table`` is handed to exp.table_ as a single name; confirm
    # how dotted (schema-qualified) names are expected to be handled.
    target = exp.table_(table)
    if columns:
        return exp.Schema(this=target, expressions=[exp.column(name) for name in columns])
    return target


def _copy_option_to_expression(value: Any) -> exp.Expr:
    """Convert a single COPY option value into a sqlglot expression node."""
    # bool is tested first because bool is a subclass of int and would
    # otherwise be rendered as a numeric literal.
    if isinstance(value, bool):
        return exp.Boolean(this=value)
    if value is None:
        return exp.null()
    if isinstance(value, (int, float)):
        return exp.Literal.number(value)
    if isinstance(value, (list, tuple)):
        # Sequence items are always rendered as string literals.
        return exp.Array(expressions=[exp.Literal.string(str(item)) for item in value])
    return exp.Literal.string(str(value))


def _build_copy_expression(
    *, direction: str, table: str, location: str, columns: "Sequence[str] | None", options: "Mapping[str, Any] | None"
) -> exp.Copy:
    """Assemble a sqlglot ``Copy`` expression.

    Args:
        direction: Either "from" or "to".
        table: Name of the table being copied.
        location: File location, emitted as a string literal.
        columns: Optional column names for the target column list.
        options: Optional mapping of COPY parameters. Keys are upper-cased;
            values are converted by ``_copy_option_to_expression``.

    Returns:
        The ``exp.Copy`` node with ``kind`` set to True for "from" and False
        for "to".

    Raises:
        SQLBuilderError: If ``direction`` is neither "from" nor "to".
    """
    # Reject unknown directions explicitly. Silently leaving ``kind`` unset
    # would let a typo such as "FROM" produce a statement whose direction is
    # decided by the renderer's default rather than by the caller.
    if direction not in {"from", "to"}:
        msg = f"Unsupported COPY direction {direction!r}; expected 'from' or 'to'"
        raise SQLBuilderError(msg)

    copy_args: dict[str, Any] = {
        "this": _to_copy_schema(table, columns),
        "files": [exp.Literal.string(location)],
        "kind": direction == "from",
    }

    if options:
        copy_args["params"] = [
            exp.CopyParameter(this=exp.Var(this=str(key).upper()), expression=_copy_option_to_expression(value))
            for key, value in options.items()
        ]

    return exp.Copy(**copy_args)


def build_copy_statement(
    *,
    direction: str,
    table: str,
    location: str,
    columns: "Sequence[str] | None" = None,
    options: "Mapping[str, Any] | None" = None,
    dialect: DialectType | None = None,
) -> SQL:
    """Build a COPY statement and wrap its rendered text in a ``SQL`` object.

    Args:
        direction: "from" or "to"; forwarded to ``_build_copy_expression``.
        table: Name of the table being copied.
        location: File location used by the statement.
        columns: Optional column names for the target column list.
        options: Optional mapping of COPY parameters.
        dialect: Dialect used for rendering; "postgres" is used when None.

    Returns:
        A ``SQL`` object created from the rendered statement text.
    """
    copy_expression = _build_copy_expression(
        direction=direction, table=table, location=location, columns=columns, options=options
    )
    return SQL(copy_expression.sql(dialect=_normalize_copy_dialect(dialect)))
def build_copy_from_statement(
    table: str,
    source: str,
    *,
    columns: "Sequence[str] | None" = None,
    options: "Mapping[str, Any] | None" = None,
    dialect: DialectType | None = None,
) -> SQL:
    """Build a COPY statement with direction "from".

    Thin wrapper around ``build_copy_statement`` that passes ``source`` as
    the statement's location.

    Args:
        table: Name of the table being copied into.
        source: Location to copy from.
        columns: Optional column names for the target column list.
        options: Optional mapping of COPY parameters.
        dialect: Dialect used for rendering.

    Returns:
        The ``SQL`` object produced by ``build_copy_statement``.
    """
    return build_copy_statement(
        direction="from", table=table, location=source, columns=columns, options=options, dialect=dialect
    )