# Source code for sqlspec.core.statement

"""SQL statement and configuration management."""

import hashlib
import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Final, TypeAlias

import sqlglot
from mypy_extensions import mypyc_attr
from sqlglot import Dialect, exp
from sqlglot.errors import ParseError

import sqlspec.exceptions
from sqlspec.core import pipeline
from sqlspec.core._pool import get_processed_state_pool, get_sql_pool
from sqlspec.core.cache import FiltersView
from sqlspec.core.compiler import OperationProfile, OperationType
from sqlspec.core.explain import ExplainFormat, ExplainOptions
from sqlspec.core.hashing import hash_filters
from sqlspec.core.parameters import (
    ParameterConverter,
    ParameterDeclaration,
    ParameterProcessor,
    ParameterProfile,
    ParameterStyle,
    ParameterStyleConfig,
    ParameterValidator,
    structural_fingerprint,
)
from sqlspec.core.query_modifiers import (
    apply_column_pruning,
    apply_limit,
    apply_offset,
    apply_select_only,
    apply_where,
    create_between_condition,
    create_condition,
    create_in_condition,
    create_not_in_condition,
    expr_eq,
    expr_gt,
    expr_gte,
    expr_ilike,
    expr_is_not_null,
    expr_is_null,
    expr_like,
    expr_lt,
    expr_lte,
    expr_neq,
    extract_column_name,
    safe_modify_with_cte,
)
from sqlspec.core.sqlcommenter import create_sqlcommenter_statement_transformer
from sqlspec.observability import resolve_db_system
from sqlspec.typing import Empty, EmptyEnum
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import is_statement_filter, supports_where

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlglot.dialects.dialect import DialectType

    from sqlspec.builder import QueryBuilder
    from sqlspec.core.filters import StatementFilter


# Names exported by ``from sqlspec.core.statement import *``.
__all__ = (
    "SQL",
    "ProcessedState",
    "Statement",
    "StatementConfig",
    "get_default_config",
    "get_default_parameter_config",
)
# Module-level logger, registered under this module's dotted path.
logger = get_logger("sqlspec.core.statement")

# Upper-case SQL keywords grouped by kind of operation.
# NOTE(review): presumably matched against a statement's operation keyword to
# decide whether it yields rows or modifies data; the consumers are outside
# this section, so confirm against them.
RETURNS_ROWS_OPERATIONS: Final = {"SELECT", "WITH", "VALUES", "TABLE", "SHOW", "DESCRIBE", "PRAGMA"}
MODIFYING_OPERATIONS: Final = {"INSERT", "UPDATE", "DELETE", "MERGE", "UPSERT"}
# Number of pieces an ORDER BY item splits into when it ends with a direction
# keyword ("<expression> asc|desc"); compared against in _parse_order_item.
_ORDER_PARTS_COUNT: Final = 2
# NOTE(review): looks like an upper bound on retries when generating a
# non-colliding parameter name; its use is not visible here, so confirm.
_MAX_PARAM_COLLISION_ATTEMPTS: Final = 1000


def _parse_order_item(order_item: str, dialect: "str | None", enable_parsing: bool) -> exp.Expr:
    """Parse a single ORDER BY item string into a SQLGlot expression.

    Args:
        order_item: One ORDER BY entry, e.g. ``"name"`` or ``"created_at desc"``.
        dialect: SQLGlot dialect name passed to the parser, or ``None``.
        enable_parsing: When true, SQLGlot is asked to parse the item first and
            the plain-string handling below is used only if that fails.

    Returns:
        The expression parsed by SQLGlot, or a column expression built from the
        raw text, wrapped in ``asc()`` / ``desc()`` when the text ends with
        that direction keyword.
    """
    normalized = order_item.strip()
    if not normalized:
        # Empty or whitespace-only input: hand the original text through unchanged.
        return exp.column(order_item)

    if enable_parsing:
        try:
            parsed = sqlglot.parse_one(normalized, dialect=dialect, into=exp.Ordered)
        except (ParseError, sqlglot.errors.TokenError):
            # TokenError is raised by the tokenizer (e.g. an unterminated quote)
            # and is not a ParseError subclass; both mean "could not parse", so
            # both fall through to the string-based handling below.
            parsed = None
        if parsed is not None:
            return parsed

    # Split off at most one trailing token; ``normalized`` is already stripped,
    # so when two parts come back the first one is never empty.
    parts = normalized.rsplit(None, 1)
    if len(parts) == _ORDER_PARTS_COUNT:
        direction = parts[1].lower()
        if direction in {"asc", "desc"}:
            base_expr = exp.column(parts[0])
            return base_expr.desc() if direction == "desc" else base_expr.asc()

    return exp.column(normalized)


# Attribute names for the statement configuration object: the public settings
# followed by three private bookkeeping entries (underscore-prefixed).
# NOTE(review): presumably assigned to ``StatementConfig.__slots__``; that class
# is outside this section, so confirm.
SQL_CONFIG_SLOTS: Final = (
    "dialect",
    "enable_analysis",
    "enable_caching",
    "enable_column_pruning",
    "enable_expression_simplification",
    "enable_parameter_type_wrapping",
    "enable_parsing",
    "enable_sqlcommenter",
    "enable_transformations",
    "enable_validation",
    "execution_mode",
    "execution_args",
    "output_transformer",
    "sqlcommenter_attributes",
    "sqlcommenter_enable_context",
    "sqlcommenter_enable_traceparent",
    "statement_transformers",
    "parameter_config",
    "parameter_converter",
    "parameter_validator",
    "_fingerprint_cache",
    "_hash_cache",
    "_is_frozen",
)

# The same names as SQL_CONFIG_SLOTS minus the underscore-prefixed entries,
# held in a frozenset for membership tests. Keep the two in sync when adding
# or removing a configuration field.
_PUBLIC_CONFIG_FIELDS: Final = frozenset((
    "dialect",
    "enable_analysis",
    "enable_caching",
    "enable_column_pruning",
    "enable_expression_simplification",
    "enable_parameter_type_wrapping",
    "enable_parsing",
    "enable_sqlcommenter",
    "enable_transformations",
    "enable_validation",
    "execution_mode",
    "execution_args",
    "output_transformer",
    "sqlcommenter_attributes",
    "sqlcommenter_enable_context",
    "sqlcommenter_enable_traceparent",
    "statement_transformers",
    "parameter_config",
    "parameter_converter",
    "parameter_validator",
))

# Slot names used as ``ProcessedState.__slots__``; every attribute assigned in
# ``ProcessedState.__init__`` must be listed here.
PROCESSED_STATE_SLOTS: Final = (
    "compiled_sql",
    "execution_parameters",
    "parsed_expression",
    "operation_type",
    "input_named_parameters",
    "applied_wrap_types",
    "filter_hash",
    "parameter_fingerprint",
    "parameter_casts",
    "parameter_profile",
    "operation_profile",
    "validation_errors",
    "is_many",
)

# Private attribute names for the SQL statement object.
# NOTE(review): presumably assigned to ``SQL.__slots__``; that class is outside
# this section, so confirm.
SQL_SLOTS: Final = (
    "_compiled_from_cache",
    "_declared_parameters",
    "_dialect",
    "_filters",
    "_hash",
    "_is_cache_direct",
    "_is_many",
    "_is_script",
    "_named_parameters",
    "_original_parameters",
    "_pooled",
    "_positional_parameters",
    "_processed_state",
    "_raw_expression",
    "_raw_sql",
    "_rebind_processor",
    "_sql_param_counters",
    "_statement_config",
)


@mypyc_attr(allow_interpreted_subclasses=False)
class ProcessedState:
    """Processing results for SQL statements.

    Contains the compiled SQL, execution parameters, parsed expression,
    operation type, and validation errors for a processed SQL statement.

    The attribute set is fixed by ``PROCESSED_STATE_SLOTS``; instances carry no
    ``__dict__``.
    """

    __slots__ = PROCESSED_STATE_SLOTS
    operation_type: "OperationType"

    def __init__(
        self,
        compiled_sql: str,
        execution_parameters: Any,
        parsed_expression: "exp.Expr | None" = None,
        operation_type: "OperationType" = "COMMAND",
        input_named_parameters: "tuple[str, ...] | None" = None,
        applied_wrap_types: bool = False,
        filter_hash: int = 0,
        parameter_fingerprint: Any | None = None,
        parameter_casts: "dict[int, str] | None" = None,
        validation_errors: "list[str] | None" = None,
        parameter_profile: "ParameterProfile | None" = None,
        operation_profile: "OperationProfile | None" = None,
        is_many: bool = False,
    ) -> None:
        """Store the results of processing one statement.

        Optional collection and profile arguments are normalized with ``or``,
        so both ``None`` and an empty/falsy value are replaced by a fresh
        default and the stored attribute is never ``None``.

        Args:
            compiled_sql: Compiled SQL text.
            execution_parameters: Parameters that accompany the compiled SQL.
            parsed_expression: Parsed SQLGlot expression, if one is available.
            operation_type: Operation type label; defaults to ``"COMMAND"``.
            input_named_parameters: Named-parameter names; stored as ``()``
                when omitted or empty.
            applied_wrap_types: Stored as given.
            filter_hash: Stored as given; defaults to ``0``.
            parameter_fingerprint: Stored as given; may be ``None``.
            parameter_casts: Mapping of ``int`` to ``str``; stored as ``{}``
                when omitted or empty.
            validation_errors: Error messages; stored as ``[]`` when omitted
                or empty.
            parameter_profile: Stored as ``ParameterProfile.empty()`` when
                omitted or falsy.
            operation_profile: Stored as ``OperationProfile.empty()`` when
                omitted or falsy.
            is_many: Stored as given.
        """
        self.compiled_sql = compiled_sql
        self.execution_parameters = execution_parameters
        self.parsed_expression = parsed_expression
        self.operation_type = operation_type
        self.input_named_parameters = input_named_parameters or ()
        self.applied_wrap_types = applied_wrap_types
        self.filter_hash = filter_hash
        self.parameter_fingerprint = parameter_fingerprint
        self.parameter_casts = parameter_casts or {}
        self.validation_errors = validation_errors or []
        self.parameter_profile = parameter_profile or ParameterProfile.empty()
        self.operation_profile = operation_profile or OperationProfile.empty()
        self.is_many = is_many

    def __hash__(self) -> int:
        """Hash the compiled SQL, the stringified parameters and the operation type.

        The parameters are passed through ``str()``, so they do not need to be
        hashable themselves.
        """
        return hash((self.compiled_sql, str(self.execution_parameters), self.operation_type))

    def __reduce__(self) -> "tuple[Any, ...]":
        """Reconstruct via the public ctor so copy/pickle work on mypyc native classes."""
        # The argument tuple is positional: keep it in the same order as the
        # parameters of ``__init__``.
        return (
            ProcessedState,
            (
                self.compiled_sql,
                self.execution_parameters,
                self.parsed_expression,
                self.operation_type,
                self.input_named_parameters,
                self.applied_wrap_types,
                self.filter_hash,
                self.parameter_fingerprint,
                self.parameter_casts,
                self.validation_errors,
                self.parameter_profile,
                self.operation_profile,
                self.is_many,
            ),
        )