# Source code for sqlspec.driver._common

"""Common driver attributes and utilities."""

import graphlib
import hashlib
import logging
import re
from contextlib import suppress
from time import perf_counter
from typing import TYPE_CHECKING, Any, ClassVar, Final, Literal, NamedTuple, NoReturn, Protocol, cast, overload

import sqlglot
from mypy_extensions import mypyc_attr, trait
from sqlglot import exp
from typing_extensions import Self

from sqlspec.builder import QueryBuilder
from sqlspec.core import (
    SQL,
    CachedStatement,
    ParameterDeclaration,
    ParameterStyle,
    SQLResult,
    Statement,
    StatementConfig,
    TypedParameter,
    get_cache,
    get_cache_config,
    matches_param_type,
    split_sql_script,
)
from sqlspec.core._pool import get_processed_state_pool, get_sql_pool
from sqlspec.core.filters import find_filter as _find_filter_impl
from sqlspec.core.metrics import StackExecutionMetrics
from sqlspec.core.parameters import ParameterProcessor, structural_fingerprint, value_fingerprint
from sqlspec.core.statement import ProcessedState
from sqlspec.data_dictionary import ForeignKeyMetadata, VersionCacheResult, VersionInfo, get_data_dictionary_loader
from sqlspec.data_dictionary._registry import get_dialect_config
from sqlspec.driver._query_cache import STMT_CACHE_MAX_SIZE, CachedQuery, QueryCache
from sqlspec.driver._storage_helpers import (
    CAPABILITY_HINTS,
    arrow_table_needs_parameter_preparation,
    arrow_table_to_rows,
    attach_partition_telemetry,
    build_ingest_telemetry,
    coerce_arrow_table,
    create_storage_job,
)
from sqlspec.exceptions import (
    ImproperConfigurationError,
    NotFoundError,
    SQLFileNotFoundError,
    SQLSpecError,
    StorageCapabilityError,
)
from sqlspec.observability import ObservabilityRuntime, get_trace_context, resolve_db_system
from sqlspec.protocols import HasDataProtocol, HasExecuteProtocol, StatementProtocol
from sqlspec.utils.dispatch import TypeDispatcher
from sqlspec.utils.logging import get_logger, log_with_context
from sqlspec.utils.schema import to_schema as _to_schema_impl
from sqlspec.utils.text import normalize_identifier
from sqlspec.utils.type_guards import (
    has_array_interface,
    has_asdict_method,
    has_cursor_metadata,
    has_dtype_str,
    has_statement_type,
    has_typecode,
    has_typecode_and_len,
    is_dict_row,
    is_mapping_like,
    is_statement_filter,
)

if TYPE_CHECKING:
    from collections import abc
    from collections.abc import Awaitable, Callable
    from types import TracebackType

    from sqlspec.core import ArrowResult, FilterTypeT, StatementFilter
    from sqlspec.core.parameters._types import ConvertedParameters
    from sqlspec.core.stack import StatementStack
    from sqlspec.data_dictionary._types import DialectConfig
    from sqlspec.storage import (
        AsyncStoragePipeline,
        StorageBridgeJob,
        StorageCapabilities,
        StorageTelemetry,
        SyncStoragePipeline,
    )
    from sqlspec.typing import ArrowTable, SchemaT, StatementParameters


__all__ = (
    "DEFAULT_EXECUTION_RESULT",
    "EXEC_CURSOR_RESULT",
    "EXEC_ROWCOUNT_OVERRIDE",
    "EXEC_SPECIAL_DATA",
    "VERSION_GROUPS_MIN_FOR_MINOR",
    "VERSION_GROUPS_MIN_FOR_PATCH",
    "AsyncExceptionHandler",
    "CachedQuery",
    "CommonDriverAttributesMixin",
    "DataDictionaryDialectMixin",
    "DataDictionaryMixin",
    "ExecutionResult",
    "ScriptExecutionResult",
    "StackExecutionObserver",
    "SyncExceptionHandler",
    "describe_stack_statement",
    "handle_single_row_error",
    "hash_stack_operations",
    "make_cache_key_hashable",
    "resolve_db_system",
)


# Shared logger for all driver-level events emitted from this module.
logger = get_logger("sqlspec.driver")

# Presumably minimum regex-group counts for parsing minor/patch version parts — confirm at use site.
VERSION_GROUPS_MIN_FOR_MINOR = 1
VERSION_GROUPS_MIN_FOR_PATCH = 2

# Integer tags plus an all-None default triple; the names suggest they index the
# (cursor_result, rowcount_override, special_data) leading fields of ExecutionResult.
EXEC_CURSOR_RESULT: Final[int] = 0
EXEC_ROWCOUNT_OVERRIDE: Final[int] = 1
EXEC_SPECIAL_DATA: Final[int] = 2
DEFAULT_EXECUTION_RESULT: Final["tuple[object | None, int | None, object | None]"] = (None, None, None)

# NOTE(review): this dict is Final-annotated but still mutable; treat as read-only.
_DEFAULT_DML_METADATA: Final = {"status_message": "OK"}
_EMPTY_DML_DATA: Final[tuple[()]] = ()
# Identity sentinels pushed on make_cache_key_hashable's work stack to request conversion
# of an already-filled placeholder list into a tuple / frozenset.
_CONVERT_TO_TUPLE = object()
_CONVERT_TO_FROZENSET = object()
# Module-level cache of TypeDispatcher instances (its use is outside this excerpt).
_TYPE_COERCION_DISPATCHERS: "dict[tuple[tuple[type, Any], ...], TypeDispatcher[Any]]" = {}

# String values of every named-placeholder ParameterStyle.
_CACHED_NAMED_STYLES: Final[frozenset[str]] = frozenset(
    named_style.value
    for named_style in (
        ParameterStyle.NAMED_COLON,
        ParameterStyle.NAMED_AT,
        ParameterStyle.NAMED_DOLLAR,
        ParameterStyle.NAMED_PYFORMAT,
    )
)


class SyncExceptionHandler(Protocol):
    """Protocol for synchronous exception handlers with deferred exception pattern.

    Exception handlers implement this protocol to avoid ABI boundary violations
    with mypyc-compiled code. Instead of raising exceptions from __exit__,
    handlers store mapped exceptions in pending_exception for the caller to raise.
    """

    @property
    def pending_exception(self) -> Exception | None: ...

    def __enter__(self) -> Self: ...

    def __exit__(
        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None"
    ) -> bool: ...


class AsyncExceptionHandler(Protocol):
    """Protocol for asynchronous exception handlers with deferred exception pattern.

    Exception handlers implement this protocol to avoid ABI boundary violations
    with mypyc-compiled code. Instead of raising exceptions from __aexit__,
    handlers store mapped exceptions in pending_exception for the caller to raise.
    """

    @property
    def pending_exception(self) -> Exception | None: ...

    async def __aenter__(self) -> Self: ...

    async def __aexit__(
        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None"
    ) -> bool: ...


class ScriptExecutionResult(NamedTuple):
    """Result from script execution with statement count information.

    The first three fields share names and order with ``ExecutionResult``;
    field order defines tuple positions, so it must not be rearranged.
    """

    cursor_result: Any  # opaque driver/cursor result object
    rowcount_override: int | None  # presumably replaces the cursor row count when not None
    special_data: Any  # adapter-specific extra payload; meaning not visible here
    statement_count: int  # number of statements in the script (presumably as split)
    successful_statements: int  # how many of those statements completed successfully


class ExecutionResult(NamedTuple):
    """Execution result containing all data needed for SQLResult building.

    Field order defines tuple positions and must not be rearranged. Only
    ``row_format`` (default ``"dict"``) and ``last_inserted_id`` (default
    ``None``) have defaults; every other field must be supplied.
    """

    cursor_result: Any  # opaque driver/cursor result object
    rowcount_override: int | None  # presumably replaces the cursor row count when not None
    special_data: Any  # adapter-specific extra payload; meaning not visible here
    selected_data: "list[Any] | None"  # fetched rows for row-returning statements, if any
    column_names: "list[str] | None"  # presumably column names aligned with selected_data
    data_row_count: int | None  # presumably len(selected_data) when rows were fetched
    statement_count: int | None  # populated for script results (see ScriptExecutionResult)
    successful_statements: int | None  # populated for script results
    is_script_result: bool  # True when produced by script execution
    is_select_result: bool  # True when the statement returned rows
    is_many_result: bool  # True when produced by an execute-many call
    row_format: str = "dict"  # format of rows in selected_data; "dict" unless overridden
    last_inserted_id: int | str | None = None  # generated id reported by the driver, if any


def make_cache_key_hashable(obj: Any) -> Any:
    """Recursively convert unhashable types to hashable ones for cache keys.

    Uses an iterative stack-based approach to avoid C-stack recursion limits
    in mypyc-compiled code.

    For array-like objects (NumPy arrays, Python arrays, etc.), we use structural
    info (dtype + shape or typecode + length) rather than content for cache keys.

    Collections are processed with stack entries that track (object, parent_list, index)
    so we can convert substructures in-place and then replace placeholders with tuples or frozensets
    only after their children are evaluated. Dictionaries are iterated in sorted order for determinism
    while sets fall back to a best-effort ordering if necessary.

    Conversion rules:
        * ``int``/``str``/``bytes``/``bool``/``float``/``None`` are returned unchanged.
        * Lists and tuples both become tuples, so ``[1, 2]`` and ``(1, 2)`` yield
          the same key.
        * Dicts become a tuple of ``(key, value)`` tuples. Only values are
          converted; keys are used as-is (they are already hashable).
        * Sets become frozensets of converted elements.
        * Any other object (including frozensets) is stored unchanged; if it is
          unhashable, the returned key is unhashable too.

    NOTE(review): a self-referencing container (e.g. a list containing itself)
    is pushed onto the stack indefinitely, so the loop never terminates.

    Args:
        obj: Object to make hashable.

    Returns:
        A hashable representation of the object. Collections become tuples,
        arrays become structural tuples like ("ndarray", dtype, shape).
    """
    # Fast path: common scalars need no conversion.
    if isinstance(obj, (int, str, bytes, bool, float, type(None))):
        return obj

    # ``root`` is a one-slot holder so the top-level object can be replaced
    # in place exactly like any nested child.
    root: list[Any] = [obj]
    stack = [(obj, root, 0)]

    while stack:
        current_obj, parent, idx = stack.pop()

        # Marker entry: all children of parent[idx] have been converted, so
        # freeze the placeholder list.
        if current_obj is _CONVERT_TO_TUPLE:
            parent[idx] = tuple(parent[idx])
            continue

        if current_obj is _CONVERT_TO_FROZENSET:
            parent[idx] = frozenset(parent[idx])
            continue

        # Array-likes are keyed by structure, never by content.
        if has_typecode_and_len(current_obj):
            parent[idx] = ("array", current_obj.typecode, len(current_obj))
            continue
        if has_typecode(current_obj):
            parent[idx] = ("array", current_obj.typecode)
            continue
        if has_array_interface(current_obj):
            try:
                dtype_str = current_obj.dtype.str if has_dtype_str(current_obj.dtype) else str(type(current_obj))
                shape = tuple(int(s) for s in current_obj.shape)
                parent[idx] = ("ndarray", dtype_str, shape)
            except (AttributeError, TypeError):
                try:
                    length = len(current_obj)
                    parent[idx] = ("array_like", type(current_obj).__name__, length)
                except (AttributeError, TypeError):
                    parent[idx] = ("array_like", type(current_obj).__name__)
            continue

        if isinstance(current_obj, (list, tuple)):
            new_list = [None] * len(current_obj)
            parent[idx] = new_list

            # Pushed before the children so it is popped after all of them (LIFO).
            stack.append((_CONVERT_TO_TUPLE, parent, idx))

            # Children pushed in reverse so they are processed in original order.
            stack.extend((current_obj[i], new_list, i) for i in range(len(current_obj) - 1, -1, -1))
            continue

        if isinstance(current_obj, dict):
            try:
                sorted_items = sorted(current_obj.items())
            except TypeError:
                # Mixed/unorderable key types: fall back to insertion order.
                sorted_items = list(current_obj.items())

            items_list = []
            for k, v in sorted_items:
                items_list.append([k, v])

            parent[idx] = items_list

            # Freezes the outer list of pairs once every pair is finished.
            stack.append((_CONVERT_TO_TUPLE, parent, idx))

            # For each pair: the value (slot 1) is popped first, then the marker
            # that turns the [key, value] list into a tuple.
            for i in range(len(items_list) - 1, -1, -1):
                stack.extend(((_CONVERT_TO_TUPLE, items_list, i), (items_list[i][1], items_list[i], 1)))

            continue

        if isinstance(current_obj, set):
            # Ordering does not affect the final frozenset; sorting only makes
            # the processing order deterministic when elements are comparable.
            try:
                sorted_list = sorted(current_obj)
            except TypeError:
                sorted_list = list(current_obj)

            new_list = [None] * len(sorted_list)
            parent[idx] = new_list

            stack.append((_CONVERT_TO_FROZENSET, parent, idx))

            stack.extend((sorted_list[i], new_list, i) for i in range(len(sorted_list) - 1, -1, -1))
            continue

        # Anything else is stored as-is (assumed hashable by the caller).
        parent[idx] = current_obj

    return root[0]


def describe_stack_statement(statement: "StatementProtocol | str") -> str:
    """Produce a human-readable label for a stack statement, for diagnostics.

    Plain strings are returned unchanged. Statement objects yield their raw SQL,
    falling back to ``repr()`` when the raw SQL is empty. Anything else yields
    ``repr()``.
    """
    if not isinstance(statement, str):
        if isinstance(statement, StatementProtocol):
            return statement.raw_sql or repr(statement)
        return repr(statement)
    return statement


def handle_single_row_error(error: ValueError) -> "NoReturn":
    """Translate a single-row lookup failure into the SQLSpec exception hierarchy.

    A ``ValueError`` whose message begins with ``"No result found"`` becomes a
    ``NotFoundError`` chained from the original. Any other error is re-raised
    unchanged.

    Raises:
        NotFoundError: When no row was found.
        ValueError: The original error in every other case.
    """
    if not str(error).startswith("No result found"):
        raise error
    msg = "No rows found"
    raise NotFoundError(msg) from error


def hash_stack_operations(stack: "StatementStack") -> "tuple[str, ...]":
    """Fingerprint every operation in a statement stack.

    Each fingerprint is the first 16 hex characters of the SHA-256 digest of
    the statement's diagnostic text (as produced by ``describe_stack_statement``).
    This lets log events reference statements without embedding their text.

    Returns:
        One fingerprint per operation, in stack order.
    """
    fingerprints: list[str] = []
    for entry in stack.operations:
        label = describe_stack_statement(entry.statement)
        # Defensive: coerce in case a statement's raw SQL is not a str.
        text = label if isinstance(label, str) else str(label)
        fingerprints.append(hashlib.sha256(text.encode("utf-8")).hexdigest()[:16])
    return tuple(fingerprints)


def _apply_declared_optional_defaults(declared: "tuple[ParameterDeclaration, ...]", supplied: "dict[str, Any]") -> None:
    """Bind missing optional named params as SQL NULL."""
    for declaration in declared:
        if not declaration.required and declaration.name not in supplied:
            supplied[declaration.name] = None


def _check_declared_named_row(declared: "tuple[ParameterDeclaration, ...]", supplied: "dict[str, Any]") -> None:
    """Check one named-parameter mapping against its declarations.

    Absent optional parameters are first set to ``None`` in ``supplied`` (so they
    bind as NULL). Then every declaration must be present, and any non-``None``
    value must satisfy ``matches_param_type`` for its declared type. Keys that
    are not declared are ignored.

    Raises:
        SQLSpecError: When a required parameter is missing or a value fails the
            declared type check.
    """
    _apply_declared_optional_defaults(declared, supplied)
    for param in declared:
        param_name = param.name
        if param_name not in supplied:
            msg = f"Missing required parameter '{param_name}' for declared SQL statement."
            raise SQLSpecError(msg)
        bound_value = supplied[param_name]
        if bound_value is not None and not matches_param_type(param.type_str, bound_value):
            msg = f"Parameter '{param_name}' expected type '{param.type_str}' but got {type(bound_value).__name__}."
            raise SQLSpecError(msg)


def _validate_declared_parameters(sql_statement: "SQL") -> None:
    """Apply the declared-parameter contract to a statement's user parameters.

    Runs before any driver-specific style conversion, so declared names and raw
    values are still intact. Does nothing for scripts or statements without
    declarations.

    - Single execution with named parameters: presence and type are validated.
    - Single execution with positional parameters: skipped, because arity is
      checked when the statement is loaded.
    - ``execute_many`` with mapping rows: optional defaults are applied to every
      mapping row, but only the first row is validated.
    """
    declared = sql_statement.declared_parameters
    if not declared or sql_statement.is_script:
        return
    if not sql_statement.is_many:
        if not sql_statement.positional_parameters:
            _check_declared_named_row(declared, sql_statement.named_parameters)
        return
    batch = sql_statement.positional_parameters
    if not batch or not isinstance(batch[0], dict):
        return
    for entry in batch:
        if isinstance(entry, dict):
            _apply_declared_optional_defaults(declared, entry)
    _check_declared_named_row(declared, batch[0])


class StackExecutionObserver:
    """Context manager that aggregates telemetry for stack execution."""

    __slots__ = (
        "continue_on_error",
        "driver",
        "hashed_operations",
        "metrics",
        "native_pipeline",
        "runtime",
        "span",
        "stack",
        "started",
    )

    def __init__(
        self,
        driver: "CommonDriverAttributesMixin",
        stack: "StatementStack",
        continue_on_error: bool,
        native_pipeline: bool,
    ) -> None:
        self.driver = driver
        self.stack = stack
        self.continue_on_error = continue_on_error
        self.native_pipeline = native_pipeline
        self.runtime = driver.observability
        self.metrics = StackExecutionMetrics(
            adapter=type(driver).__name__,
            statement_count=len(stack.operations),
            continue_on_error=continue_on_error,
            native_pipeline=native_pipeline,
            forced_disable=driver.stack_native_disabled,
        )
        self.hashed_operations = hash_stack_operations(stack)
        self.span: Any | None = None
        self.started = 0.0

    def __enter__(self) -> Self:
        self.started = perf_counter()
        trace_id, span_id = get_trace_context()
        attributes = {
            "sqlspec.stack.statement_count": len(self.stack.operations),
            "sqlspec.stack.continue_on_error": self.continue_on_error,
            "sqlspec.stack.native_pipeline": self.native_pipeline,
            "sqlspec.stack.forced_disable": self.driver.stack_native_disabled,
        }
        self.span = self.runtime.start_span("sqlspec.stack.execute", attributes=attributes)
        log_with_context(
            logger,
            logging.DEBUG,
            "stack.execute.start",
            driver=type(self.driver).__name__,
            db_system=resolve_db_system(type(self.driver).__name__),
            stack_size=len(self.stack.operations),
            continue_on_error=self.continue_on_error,
            native_pipeline=self.native_pipeline,
            forced_disable=self.driver.stack_native_disabled,
            hashed_operations=self.hashed_operations,
            trace_id=trace_id,
            span_id=span_id,
        )
        return self

    def __exit__(
        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None"
    ) -> Literal[False]:
        duration = perf_counter() - self.started
        self.metrics.record_duration(duration)
        if isinstance(exc_val, Exception):
            self.metrics.record_error(exc_val)
        self.runtime.span_manager.end_span(self.span, error=exc_val if exc_val is not None else None)
        self.metrics.emit(self.runtime)
        level = logging.ERROR if exc_val is not None else logging.DEBUG
        trace_id, span_id = get_trace_context()
        log_with_context(
            logger,
            level,
            "stack.execute.failed" if exc_val is not None else "stack.execute.complete",
            driver=type(self.driver).__name__,
            db_system=resolve_db_system(type(self.driver).__name__),
            stack_size=len(self.stack.operations),
            continue_on_error=self.continue_on_error,
            native_pipeline=self.native_pipeline,
            forced_disable=self.driver.stack_native_disabled,
            hashed_operations=self.hashed_operations,
            duration_ms=duration * 1000,
            error_type=type(exc_val).__name__ if exc_val is not None else None,
            trace_id=trace_id,
            span_id=span_id,
        )
        return False

    def record_operation_error(self, error: Exception) -> None:
        """Record an operation error when continue-on-error is enabled."""
        self.metrics.record_operation_error(error)


@mypyc_attr(allow_interpreted_subclasses=True)
@trait
class DataDictionaryDialectMixin:
    """Dialect-aware helpers shared by data dictionary implementations.

    Subclasses set the ``dialect`` class attribute. Every helper here resolves
    configuration, named queries, and identifier normalization for
    ``type(self).dialect``.
    """

    __slots__ = ()

    dialect: "ClassVar[str]"

    def get_dialect_config(self) -> "DialectConfig":
        """Look up the registered configuration for this class's dialect."""
        dialect_name = type(self).dialect
        return get_dialect_config(dialect_name)

    def get_query(self, name: str) -> "SQL":
        """Load the named data-dictionary query as a ``SQL`` object for this dialect."""
        return get_data_dictionary_loader().get_query(type(self).dialect, name)

    def get_query_text(self, name: str) -> str:
        """Load the raw SQL text of the named query for this dialect."""
        return get_data_dictionary_loader().get_query_text(type(self).dialect, name)

    def get_query_text_or_none(self, name: str) -> "str | None":
        """Like ``get_query_text`` but return ``None`` when the query file does not exist."""
        try:
            text = self.get_query_text(name)
        except SQLFileNotFoundError:
            return None
        return text

    def resolve_schema(self, schema: "str | None") -> "str | None":
        """Normalize ``schema``, falling back to the dialect default when it is ``None``.

        Returns ``None`` only when no schema was given and the dialect has no default.
        """
        config = self.get_dialect_config()
        candidate = config.default_schema if schema is None else schema
        if candidate is None:
            return None
        return normalize_identifier(candidate, config.name)

    def resolve_identifier(self, identifier: str) -> str:
        """Normalize ``identifier`` according to this dialect's rules."""
        config = self.get_dialect_config()
        return normalize_identifier(identifier, config.name)

    def resolve_feature_flag(self, feature: str, version: "VersionInfo | None") -> bool:
        """Decide whether ``feature`` is available.

        An explicit flag in the dialect config wins. Otherwise the feature is
        enabled only when a minimum version is configured and ``version`` is
        known and at least that minimum.
        """
        config = self.get_dialect_config()
        explicit_flag = config.get_feature_flag(feature)
        if explicit_flag is not None:
            return explicit_flag
        minimum_version = config.get_feature_version(feature)
        return minimum_version is not None and version is not None and bool(version >= minimum_version)

    def get_default_features(self) -> "list[str]":
        """Return built-in feature names; this base implementation has none."""
        return []

    def list_available_features(self) -> "list[str]":
        """List every feature name that can be queried for this dialect.

        Returns:
            Sorted, de-duplicated union of the default features and the features
            named in the dialect config's flag and version tables.
        """
        config = self.get_dialect_config()
        feature_names = {
            *self.get_default_features(),
            *config.feature_flags.keys(),
            *config.feature_versions.keys(),
        }
        return sorted(feature_names)


@mypyc_attr(allow_interpreted_subclasses=True)
@trait
class DataDictionaryMixin:
    """Mixin providing common data dictionary functionality.

    Includes version caching to avoid repeated database queries when checking
    feature flags or optimal types.
    """

    __slots__ = ()

    dialect: "ClassVar[str]"

    _version_cache: "dict[int, VersionInfo | None]"
    _version_fetch_attempted: "set[int]"

    def get_cached_version(self, driver_id: int) -> "VersionCacheResult":
        """Get cached version info for a driver.

        Args:
            driver_id: The id() of the driver instance.

        Returns:
            Tuple of (was_cached, version_info). ``was_cached`` becomes True once
            ``cache_version()`` has been called for this driver, even when the
            recorded version was None (detection failed). If was_cached is False,
            the caller should fetch the version and call cache_version().
        """
        if driver_id in self._version_fetch_attempted:
            return True, self._version_cache.get(driver_id)
        return False, None

    def cache_version(self, driver_id: int, version: "VersionInfo | None") -> None:
        """Cache version info for a driver.

        The driver is always marked as attempted, so a failed detection is not
        retried. Only non-None versions are stored in the version cache.

        NOTE(review): keys are ``id()`` values, which CPython may reuse once a
        driver is garbage-collected; confirm entries cannot outlive their driver.

        Args:
            driver_id: The id() of the driver instance.
            version: The version info to cache (can be None if detection failed).
        """
        self._version_fetch_attempted.add(driver_id)
        if version is not None:
            self._version_cache[driver_id] = version