"""Common driver attributes and utilities."""
import graphlib
import hashlib
import logging
import re
from contextlib import suppress
from time import perf_counter
from typing import TYPE_CHECKING, Any, ClassVar, Final, Literal, NamedTuple, NoReturn, Protocol, cast, overload
import sqlglot
from mypy_extensions import mypyc_attr, trait
from sqlglot import exp
from typing_extensions import Self
from sqlspec.builder import QueryBuilder
from sqlspec.core import (
SQL,
CachedStatement,
ParameterDeclaration,
ParameterStyle,
SQLResult,
Statement,
StatementConfig,
TypedParameter,
get_cache,
get_cache_config,
matches_param_type,
split_sql_script,
)
from sqlspec.core._pool import get_processed_state_pool, get_sql_pool
from sqlspec.core.filters import find_filter as _find_filter_impl
from sqlspec.core.metrics import StackExecutionMetrics
from sqlspec.core.parameters import ParameterProcessor, structural_fingerprint, value_fingerprint
from sqlspec.core.statement import ProcessedState
from sqlspec.data_dictionary import ForeignKeyMetadata, VersionCacheResult, VersionInfo, get_data_dictionary_loader
from sqlspec.data_dictionary._registry import get_dialect_config
from sqlspec.driver._query_cache import STMT_CACHE_MAX_SIZE, CachedQuery, QueryCache
from sqlspec.driver._storage_helpers import (
CAPABILITY_HINTS,
arrow_table_needs_parameter_preparation,
arrow_table_to_rows,
attach_partition_telemetry,
build_ingest_telemetry,
coerce_arrow_table,
create_storage_job,
)
from sqlspec.exceptions import (
ImproperConfigurationError,
NotFoundError,
SQLFileNotFoundError,
SQLSpecError,
StorageCapabilityError,
)
from sqlspec.observability import ObservabilityRuntime, get_trace_context, resolve_db_system
from sqlspec.protocols import HasDataProtocol, HasExecuteProtocol, StatementProtocol
from sqlspec.utils.dispatch import TypeDispatcher
from sqlspec.utils.logging import get_logger, log_with_context
from sqlspec.utils.schema import to_schema as _to_schema_impl
from sqlspec.utils.text import normalize_identifier
from sqlspec.utils.type_guards import (
has_array_interface,
has_asdict_method,
has_cursor_metadata,
has_dtype_str,
has_statement_type,
has_typecode,
has_typecode_and_len,
is_dict_row,
is_mapping_like,
is_statement_filter,
)
if TYPE_CHECKING:
from collections import abc
from collections.abc import Awaitable, Callable
from types import TracebackType
from sqlspec.core import ArrowResult, FilterTypeT, StatementFilter
from sqlspec.core.parameters._types import ConvertedParameters
from sqlspec.core.stack import StatementStack
from sqlspec.data_dictionary._types import DialectConfig
from sqlspec.storage import (
AsyncStoragePipeline,
StorageBridgeJob,
StorageCapabilities,
StorageTelemetry,
SyncStoragePipeline,
)
from sqlspec.typing import ArrowTable, SchemaT, StatementParameters
# Public API of this module. ``CachedQuery`` and ``resolve_db_system`` are not defined
# here: they are imported above from other sqlspec modules and re-exported.
__all__ = (
"DEFAULT_EXECUTION_RESULT",
"EXEC_CURSOR_RESULT",
"EXEC_ROWCOUNT_OVERRIDE",
"EXEC_SPECIAL_DATA",
"VERSION_GROUPS_MIN_FOR_MINOR",
"VERSION_GROUPS_MIN_FOR_PATCH",
"AsyncExceptionHandler",
"CachedQuery",
"CommonDriverAttributesMixin",
"DataDictionaryDialectMixin",
"DataDictionaryMixin",
"ExecutionResult",
"ScriptExecutionResult",
"StackExecutionObserver",
"SyncExceptionHandler",
"describe_stack_statement",
"handle_single_row_error",
"hash_stack_operations",
"make_cache_key_hashable",
"resolve_db_system",
)
# Module logger, registered under the "sqlspec.driver" name.
logger = get_logger("sqlspec.driver")
# NOTE(review): presumably thresholds on how many version components were parsed before
# a minor / patch number is read; the consumer is not visible in this section — confirm.
VERSION_GROUPS_MIN_FOR_MINOR = 1
VERSION_GROUPS_MIN_FOR_PATCH = 2
# Integer positions 0..2. NOTE(review): these appear to index the three slots of
# DEFAULT_EXECUTION_RESULT below (cursor result, rowcount override, special data) —
# confirm against the drivers that consume them.
EXEC_CURSOR_RESULT: Final[int] = 0
EXEC_ROWCOUNT_OVERRIDE: Final[int] = 1
EXEC_SPECIAL_DATA: Final[int] = 2
# Three-slot tuple holding ``None`` in every position.
DEFAULT_EXECUTION_RESULT: Final["tuple[object | None, int | None, object | None]"] = (None, None, None)
# NOTE(review): shared defaults, presumably for DML results; users are not visible in
# this section. The dict is mutable, so it should be treated as read-only.
_DEFAULT_DML_METADATA: Final = {"status_message": "OK"}
_EMPTY_DML_DATA: Final[tuple[()]] = ()
# Unique sentinels pushed onto the work stack by make_cache_key_hashable to request a
# deferred list -> tuple / list -> frozenset conversion of an already-filled slot.
_CONVERT_TO_TUPLE = object()
_CONVERT_TO_FROZENSET = object()
# Registry of TypeDispatcher objects keyed by a tuple of (type, value) pairs, per the
# annotation. NOTE(review): population and lookup happen outside this section.
_TYPE_COERCION_DISPATCHERS: "dict[tuple[tuple[type, Any], ...], TypeDispatcher[Any]]" = {}
# The ``.value`` strings of the four named ParameterStyle members.
_CACHED_NAMED_STYLES: Final[frozenset[str]] = frozenset((
ParameterStyle.NAMED_COLON.value,
ParameterStyle.NAMED_AT.value,
ParameterStyle.NAMED_DOLLAR.value,
ParameterStyle.NAMED_PYFORMAT.value,
))
class SyncExceptionHandler(Protocol):
"""Protocol for synchronous exception handlers with deferred exception pattern.
Exception handlers implement this protocol to avoid ABI boundary violations
with mypyc-compiled code. Instead of raising exceptions from __exit__,
handlers store mapped exceptions in pending_exception for the caller to raise.
"""
@property
def pending_exception(self) -> Exception | None: ...
def __enter__(self) -> Self: ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None"
) -> bool: ...
class AsyncExceptionHandler(Protocol):
"""Protocol for asynchronous exception handlers with deferred exception pattern.
Exception handlers implement this protocol to avoid ABI boundary violations
with mypyc-compiled code. Instead of raising exceptions from __aexit__,
handlers store mapped exceptions in pending_exception for the caller to raise.
"""
@property
def pending_exception(self) -> Exception | None: ...
async def __aenter__(self) -> Self: ...
async def __aexit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None"
) -> bool: ...
class ScriptExecutionResult(NamedTuple):
"""Result from script execution with statement count information."""
cursor_result: Any
rowcount_override: int | None
special_data: Any
statement_count: int
successful_statements: int
class ExecutionResult(NamedTuple):
"""Execution result containing all data needed for SQLResult building."""
cursor_result: Any
rowcount_override: int | None
special_data: Any
selected_data: "list[Any] | None"
column_names: "list[str] | None"
data_row_count: int | None
statement_count: int | None
successful_statements: int | None
is_script_result: bool
is_select_result: bool
is_many_result: bool
row_format: str = "dict"
last_inserted_id: int | str | None = None
def make_cache_key_hashable(obj: Any) -> Any:
    """Turn an arbitrary (possibly nested) value into something usable as a cache key.

    The traversal is iterative: an explicit work stack replaces recursion so that
    deeply nested inputs do not hit C-stack recursion limits in mypyc-compiled code.

    Each work item is ``(value, owner_list, index)``. A container is first replaced
    in its owner slot by a mutable list, a sentinel work item is queued for that
    slot, and the children are queued on top of it. Because the stack is LIFO the
    children are resolved first; when the sentinel is popped the finished list is
    frozen into a tuple (lists, tuples, dicts) or a frozenset (sets).

    Array-like objects are keyed by structure, not content: typecode (+ length) or
    dtype + shape. Dict items are sorted when comparable, otherwise kept in
    iteration order; sets are sorted on a best-effort basis too.

    Args:
        obj: Object to make hashable.

    Returns:
        A hashable stand-in for ``obj``. Collections become tuples (sets become
        frozensets) and arrays become structural tuples such as
        ``("ndarray", dtype, shape)``. Values of any other type are returned as-is.
    """
    # Fast path: plain scalars are already valid keys.
    if isinstance(obj, (int, str, bytes, bool, float, type(None))):
        return obj

    holder: list[Any] = [obj]
    pending = [(obj, holder, 0)]
    while pending:
        item, owner, slot = pending.pop()

        if item is _CONVERT_TO_TUPLE:
            # All children of this slot are done; freeze the scratch list.
            owner[slot] = tuple(owner[slot])
        elif item is _CONVERT_TO_FROZENSET:
            owner[slot] = frozenset(owner[slot])
        elif has_typecode_and_len(item):
            owner[slot] = ("array", item.typecode, len(item))
        elif has_typecode(item):
            owner[slot] = ("array", item.typecode)
        elif has_array_interface(item):
            try:
                dtype_repr = item.dtype.str if has_dtype_str(item.dtype) else str(type(item))
                dims = tuple(int(extent) for extent in item.shape)
                owner[slot] = ("ndarray", dtype_repr, dims)
            except (AttributeError, TypeError):
                # No usable dtype/shape: fall back to type name, with length if available.
                try:
                    size = len(item)
                    owner[slot] = ("array_like", type(item).__name__, size)
                except (AttributeError, TypeError):
                    owner[slot] = ("array_like", type(item).__name__)
        elif isinstance(item, (list, tuple)):
            scratch = [None] * len(item)
            owner[slot] = scratch
            pending.append((_CONVERT_TO_TUPLE, owner, slot))
            for position in reversed(range(len(item))):
                pending.append((item[position], scratch, position))
        elif isinstance(item, dict):
            try:
                ordered_items = sorted(item.items())
            except TypeError:
                ordered_items = list(item.items())
            pairs = [[key, value] for key, value in ordered_items]
            owner[slot] = pairs
            pending.append((_CONVERT_TO_TUPLE, owner, slot))
            for position in reversed(range(len(pairs))):
                pair = pairs[position]
                # Freeze the [key, value] pair only after its value has been resolved.
                pending.append((_CONVERT_TO_TUPLE, pairs, position))
                pending.append((pair[1], pair, 1))
        elif isinstance(item, set):
            try:
                members = sorted(item)
            except TypeError:
                members = list(item)
            scratch = [None] * len(members)
            owner[slot] = scratch
            pending.append((_CONVERT_TO_FROZENSET, owner, slot))
            for position in reversed(range(len(members))):
                pending.append((members[position], scratch, position))
        else:
            owner[slot] = item

    return holder[0]
def describe_stack_statement(statement: "StatementProtocol | str") -> str:
    """Produce a human-readable description of a stack statement for diagnostics.

    Args:
        statement: Raw SQL text or an object satisfying ``StatementProtocol``.

    Returns:
        The string itself for ``str`` input; for protocol objects the ``raw_sql``
        value when truthy; otherwise ``repr(statement)``.
    """
    if isinstance(statement, str):
        return statement
    if not isinstance(statement, StatementProtocol):
        return repr(statement)
    return statement.raw_sql or repr(statement)
def handle_single_row_error(error: ValueError) -> "NoReturn":
    """Translate single-row selection failures into SQLSpec exceptions.

    Args:
        error: The ``ValueError`` raised while selecting a single row.

    Raises:
        NotFoundError: When the error message starts with ``"No result found"``;
            the original error is chained as the cause.
        ValueError: The original ``error`` is re-raised for any other message.
    """
    if not str(error).startswith("No result found"):
        raise error
    msg = "No rows found"
    raise NotFoundError(msg) from error
def hash_stack_operations(stack: "StatementStack") -> "tuple[str, ...]":
    """Fingerprint every statement in the stack.

    Args:
        stack: Stack whose ``operations`` each expose a ``statement``.

    Returns:
        One entry per operation, in order: the first 16 hex characters of the
        SHA-256 digest of the statement's UTF-8 encoded description.
    """
    descriptions = (describe_stack_statement(operation.statement) for operation in stack.operations)
    return tuple(
        # Coerce defensively in case the description is not already a str.
        hashlib.sha256((text if isinstance(text, str) else str(text)).encode("utf-8")).hexdigest()[:16]
        for text in descriptions
    )
def _apply_declared_optional_defaults(declared: "tuple[ParameterDeclaration, ...]", supplied: "dict[str, Any]") -> None:
"""Bind missing optional named params as SQL NULL."""
for declaration in declared:
if not declaration.required and declaration.name not in supplied:
supplied[declaration.name] = None
def _check_declared_named_row(declared: "tuple[ParameterDeclaration, ...]", supplied: "dict[str, Any]") -> None:
    """Check one named-parameter mapping against the declared parameters.

    Optional parameters that are absent are first bound to ``None`` (so SQL sees
    ``NULL``); ``supplied`` is mutated for this. After that, every declared name
    must be present, and each non-``None`` value is passed to
    ``matches_param_type`` together with the declared type string. Keys in
    ``supplied`` that are not declared are never inspected.

    Raises:
        SQLSpecError: When a declared parameter is still missing after defaults
            were applied, or when a value is rejected by ``matches_param_type``.
    """
    _apply_declared_optional_defaults(declared, supplied)
    for declaration in declared:
        name = declaration.name
        if name not in supplied:
            # Optional names were defaulted above, so anything missing is required.
            msg = f"Missing required parameter '{name}' for declared SQL statement."
            raise SQLSpecError(msg)
        value = supplied[name]
        if value is not None and not matches_param_type(declaration.type_str, value):
            msg = f"Parameter '{name}' expected type '{declaration.type_str}' but got {type(value).__name__}."
            raise SQLSpecError(msg)
def _validate_declared_parameters(sql_statement: "SQL") -> None:
"""Enforce declared-parameter contracts on the original user params.
No-op unless the statement carries declarations and is not a script. Runs before
driver style conversion, so declared names and raw values are intact. Named binding
is validated for presence and type; ``execute_many`` checks the first row only;
positional binding is skipped (arity is validated at load time).
"""
declared = sql_statement.declared_parameters
if not declared or sql_statement.is_script:
return
if sql_statement.is_many:
rows = sql_statement.positional_parameters
if rows and isinstance(rows[0], dict):
for row in rows:
if isinstance(row, dict):
_apply_declared_optional_defaults(declared, row)
_check_declared_named_row(declared, rows[0])
return
if sql_statement.positional_parameters:
return
_check_declared_named_row(declared, sql_statement.named_parameters)
class StackExecutionObserver:
"""Context manager that aggregates telemetry for stack execution."""
__slots__ = (
"continue_on_error",
"driver",
"hashed_operations",
"metrics",
"native_pipeline",
"runtime",
"span",
"stack",
"started",
)
def __init__(
self,
driver: "CommonDriverAttributesMixin",
stack: "StatementStack",
continue_on_error: bool,
native_pipeline: bool,
) -> None:
self.driver = driver
self.stack = stack
self.continue_on_error = continue_on_error
self.native_pipeline = native_pipeline
self.runtime = driver.observability
self.metrics = StackExecutionMetrics(
adapter=type(driver).__name__,
statement_count=len(stack.operations),
continue_on_error=continue_on_error,
native_pipeline=native_pipeline,
forced_disable=driver.stack_native_disabled,
)
self.hashed_operations = hash_stack_operations(stack)
self.span: Any | None = None
self.started = 0.0
def __enter__(self) -> Self:
self.started = perf_counter()
trace_id, span_id = get_trace_context()
attributes = {
"sqlspec.stack.statement_count": len(self.stack.operations),
"sqlspec.stack.continue_on_error": self.continue_on_error,
"sqlspec.stack.native_pipeline": self.native_pipeline,
"sqlspec.stack.forced_disable": self.driver.stack_native_disabled,
}
self.span = self.runtime.start_span("sqlspec.stack.execute", attributes=attributes)
log_with_context(
logger,
logging.DEBUG,
"stack.execute.start",
driver=type(self.driver).__name__,
db_system=resolve_db_system(type(self.driver).__name__),
stack_size=len(self.stack.operations),
continue_on_error=self.continue_on_error,
native_pipeline=self.native_pipeline,
forced_disable=self.driver.stack_native_disabled,
hashed_operations=self.hashed_operations,
trace_id=trace_id,
span_id=span_id,
)
return self
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None"
) -> Literal[False]:
duration = perf_counter() - self.started
self.metrics.record_duration(duration)
if isinstance(exc_val, Exception):
self.metrics.record_error(exc_val)
self.runtime.span_manager.end_span(self.span, error=exc_val if exc_val is not None else None)
self.metrics.emit(self.runtime)
level = logging.ERROR if exc_val is not None else logging.DEBUG
trace_id, span_id = get_trace_context()
log_with_context(
logger,
level,
"stack.execute.failed" if exc_val is not None else "stack.execute.complete",
driver=type(self.driver).__name__,
db_system=resolve_db_system(type(self.driver).__name__),
stack_size=len(self.stack.operations),
continue_on_error=self.continue_on_error,
native_pipeline=self.native_pipeline,
forced_disable=self.driver.stack_native_disabled,
hashed_operations=self.hashed_operations,
duration_ms=duration * 1000,
error_type=type(exc_val).__name__ if exc_val is not None else None,
trace_id=trace_id,
span_id=span_id,
)
return False
def record_operation_error(self, error: Exception) -> None:
"""Record an operation error when continue-on-error is enabled."""
self.metrics.record_operation_error(error)
@mypyc_attr(allow_interpreted_subclasses=True)
@trait
class DataDictionaryDialectMixin:
    """Mixin with dialect-aware SQL helpers for data dictionaries.

    Every helper keys off the ``dialect`` class variable of the concrete class.
    """

    __slots__ = ()
    dialect: "ClassVar[str]"

    def get_dialect_config(self) -> "DialectConfig":
        """Look up the dialect configuration registered for this class's dialect."""
        dialect_name = type(self).dialect
        return get_dialect_config(dialect_name)

    def get_query(self, name: str) -> "SQL":
        """Fetch the named SQL query for this dialect from the data dictionary loader."""
        return get_data_dictionary_loader().get_query(type(self).dialect, name)

    def get_query_text(self, name: str) -> str:
        """Fetch the raw SQL text of a named query for this dialect."""
        return get_data_dictionary_loader().get_query_text(type(self).dialect, name)

    def get_query_text_or_none(self, name: str) -> "str | None":
        """Like ``get_query_text`` but yields ``None`` when the SQL file is not found."""
        with suppress(SQLFileNotFoundError):
            return self.get_query_text(name)
        return None

    def resolve_schema(self, schema: "str | None") -> "str | None":
        """Normalize a schema name, falling back to the dialect's default schema.

        Returns:
            The normalized explicit schema when given; otherwise the normalized
            dialect default, or ``None`` when the dialect has no default.
        """
        config = self.get_dialect_config()
        chosen = schema if schema is not None else config.default_schema
        if chosen is None:
            return None
        return normalize_identifier(chosen, config.name)

    def resolve_identifier(self, identifier: str) -> str:
        """Normalize an identifier according to this dialect."""
        config = self.get_dialect_config()
        return normalize_identifier(identifier, config.name)

    def resolve_feature_flag(self, feature: str, version: "VersionInfo | None") -> bool:
        """Decide whether a feature is available.

        An explicit flag in the dialect config wins. Otherwise the feature is
        available only when both a required version and ``version`` are known
        and ``version >= required``.
        """
        config = self.get_dialect_config()
        explicit = config.get_feature_flag(feature)
        if explicit is not None:
            return explicit
        minimum = config.get_feature_version(feature)
        if minimum is not None and version is not None:
            return bool(version >= minimum)
        return False

    def get_default_features(self) -> "list[str]":
        """Default feature names; empty here. Overridden by DataDictionaryMixin."""
        return []

    def list_available_features(self) -> "list[str]":
        """List every feature name known for this data dictionary.

        Returns:
            Sorted, de-duplicated union of the default features and the keys of
            the dialect config's ``feature_flags`` and ``feature_versions``.
        """
        config = self.get_dialect_config()
        names = {
            *self.get_default_features(),
            *config.feature_flags.keys(),
            *config.feature_versions.keys(),
        }
        return sorted(names)
@mypyc_attr(allow_interpreted_subclasses=True)
@trait
class DataDictionaryMixin:
"""Mixin providing common data dictionary functionality.
Includes version caching to avoid repeated database queries when checking
feature flags or optimal types.
"""
__slots__ = ()
dialect: "ClassVar[str]"
_version_cache: "dict[int, VersionInfo | None]"
_version_fetch_attempted: "set[int]"
def get_cached_version(self, driver_id: int) -> "VersionCacheResult":
    """Look up cached version info for a driver.

    Args:
        driver_id: The id() of the driver instance.

    Returns:
        Tuple of (was_cached, version_info). ``was_cached`` is True once a
        fetch has been attempted for ``driver_id``, even if the stored version
        is ``None``. When it is False the caller should fetch the version and
        call cache_version().
    """
    if driver_id not in self._version_fetch_attempted:
        return False, None
    return True, self._version_cache.get(driver_id)