# Source code for sqlspec.builder._base

"""Base query builder with validation and parameter binding.

Provides abstract base classes and core functionality for SQL query builders.
"""

import hashlib
import re
import uuid
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Mapping
from typing import Any, NoReturn, cast

import sqlglot
from sqlglot import Dialect, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import ParseError as SQLGlotParseError
from sqlglot.optimizer import RULES, optimize
from sqlglot.optimizer.optimize_joins import optimize_joins as _optimize_joins_rule
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates as _pushdown_predicates_rule
from sqlglot.optimizer.simplify import simplify as _simplify_rule
from typing_extensions import Self

from sqlspec.builder._vector_distance import has_vector_distance_ancestor
from sqlspec.core import (
    SQL,
    ParameterStyle,
    ParameterStyleConfig,
    SQLResult,
    StatementConfig,
    get_cache,
    get_cache_config,
    hash_expression,
    hash_optimized_expression,
)
from sqlspec.core.filters import StatementFilter
from sqlspec.exceptions import SQLBuilderError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import has_expression_and_parameters, has_name, has_with_method, is_expression

# Names exported by ``from <this module> import *``.
__all__ = ("BuiltQuery", "ExpressionBuilder", "QueryBuilder")

# NOTE(review): presumably the upper bound on retries when generating a unique
# parameter name; the code that reads it is outside this section - confirm.
MAX_PARAMETER_COLLISION_ATTEMPTS = 1000
# Matches a whole string of the form ``param_<digits>`` and captures the digits
# in the named group ``index``.
PARAMETER_INDEX_PATTERN = re.compile(r"^param_(?P<index>\d+)$")


class _ExpressionParameterizer:
    """Callable node visitor that swaps literal values for named placeholders.

    Each call receives a single expression node. ``exp.Literal`` nodes are
    registered with the owning builder via ``add_parameter_for_expression`` and
    replaced by an ``exp.Placeholder`` carrying the returned parameter name.
    Every other node is handed back untouched.
    """

    __slots__ = ("_builder",)

    def __init__(self, builder: "QueryBuilder") -> None:
        """Remember the builder that receives extracted parameter values.

        Args:
            builder: Builder whose ``add_parameter_for_expression`` is invoked
                for each literal that gets parameterized.
        """
        self._builder = builder

    def __call__(self, node: exp.Expr) -> exp.Expr:
        """Return a placeholder for a parameterizable literal, else ``node``.

        Args:
            node: Expression node to inspect.

        Returns:
            A new ``exp.Placeholder`` named after the registered parameter when
            ``node`` is a literal that should be bound; otherwise ``node``
            itself, unchanged.
        """
        if not isinstance(node, exp.Literal):
            return node

        raw = node.this
        # Identity checks on purpose: a membership test against
        # ``{True, False, None}`` compares by equality, and since
        # ``1 == True`` / ``0 == False`` a numeric 0 or 1 payload would be
        # mistaken for a boolean and silently left un-parameterized.
        if raw is True or raw is False or raw is None:
            return node

        # Literals sitting directly inside an array for which
        # ``has_vector_distance_ancestor`` reports true are kept inline.
        if isinstance(node.parent, exp.Array) and has_vector_distance_ancestor(node):
            return node

        value = raw
        if node.is_number and isinstance(raw, str):
            # Numeric literals stored as text are bound as real numbers:
            # float when a decimal point or exponent is present, int otherwise.
            try:
                value = float(raw) if "." in raw or "e" in raw.lower() else int(raw)
            except ValueError:
                # Not parseable by int()/float(); bind the original text as-is.
                value = raw

        param_name = self._builder.add_parameter_for_expression(value, context="where")
        return exp.Placeholder(this=param_name)


class _PlaceholderReplacer:
    """Callable node visitor that renames placeholders using a lookup table.

    A placeholder whose name (as ``str(node.this)``) is a key of the mapping is
    replaced by a fresh ``exp.Placeholder`` carrying the mapped name. All other
    nodes are returned as received.
    """

    __slots__ = ("_param_mapping",)

    def __init__(self, param_mapping: dict[str, str]) -> None:
        """Store the old-name to new-name mapping.

        Args:
            param_mapping: Mapping from existing placeholder names to the names
                that should replace them.
        """
        self._param_mapping = param_mapping

    def __call__(self, node: exp.Expr) -> exp.Expr:
        """Return a renamed placeholder when ``node`` is listed in the mapping.

        Args:
            node: Expression node to inspect.

        Returns:
            A new ``exp.Placeholder`` with the mapped name, or ``node`` itself
            when it is not a placeholder or its name is not in the mapping.
        """
        if not isinstance(node, exp.Placeholder):
            return node

        current_name = str(node.this)
        if current_name not in self._param_mapping:
            return node

        return exp.Placeholder(this=self._param_mapping[current_name])


def _unquote_identifier(node: exp.Expr) -> exp.Expr:
    """Clear the ``quoted`` flag on identifier nodes.

    The node is modified in place and the same object is returned.

    Args:
        node: Expression node to inspect.

    Returns:
        ``node`` itself; identifiers come back with ``quoted`` set to ``False``,
        any other node is untouched.
    """
    if not isinstance(node, exp.Identifier):
        return node

    node.set("quoted", False)
    return node


# Module-level logger, named after this module via ``__name__``.
logger = get_logger(__name__)


class BuiltQuery:
    """SQL query with bound parameters.

    Holds the rendered SQL text, the parameter values keyed by name, and the
    dialect the query was built for. Instances compare by value and are
    hashable.
    """

    __slots__ = ("dialect", "parameters", "sql")

    def __init__(self, sql: str, parameters: dict[str, Any] | None = None, dialect: DialectType | None = None) -> None:
        """Store the query text, its parameters and the dialect.

        Args:
            sql: SQL text of the query.
            parameters: Parameter values keyed by name. The mapping is stored
                as given (not copied); ``None`` becomes a new empty dict.
            dialect: Dialect associated with the query, if any.
        """
        self.sql = sql
        self.parameters = parameters if parameters is not None else {}
        self.dialect = dialect

    def __repr__(self) -> str:
        """Return a debug representation listing sorted parameter names only.

        Parameter values are not included in the output.
        """
        parameter_keys = sorted(self.parameters)
        return f"BuiltQuery(sql={self.sql!r}, parameters={parameter_keys!r}, dialect={self.dialect!r})"

    def __eq__(self, other: object) -> bool:
        """Compare SQL text, parameters and dialect with another ``BuiltQuery``.

        Returns:
            ``NotImplemented`` when ``other`` is not a ``BuiltQuery`` so Python
            can try the reflected comparison; otherwise whether all three
            fields are equal.
        """
        if not isinstance(other, BuiltQuery):
            return NotImplemented
        return self.sql == other.sql and self.parameters == other.parameters and self.dialect == other.dialect

    def __hash__(self) -> int:
        """Return a hash consistent with ``__eq__``.

        Name/value pairs are hashed when every value is hashable. Values such
        as lists or dicts are not hashable, so in that case only the parameter
        names contribute. Equal queries have equal names, which keeps equal
        objects hashing equal.
        """
        try:
            parameter_fingerprint = frozenset(self.parameters.items())
        except TypeError:
            # At least one value is unhashable (e.g. a list bound to an array
            # parameter); fall back to the names instead of failing.
            parameter_fingerprint = frozenset(self.parameters)
        return hash((self.sql, parameter_fingerprint, self.dialect))