# Source code for sqlspec.adapters.sqlite.driver

"""SQLite driver implementation."""

import sqlite3
from typing import TYPE_CHECKING, Any, cast

from sqlspec.adapters.sqlite._typing import SqliteCursor, SqliteSessionContext
from sqlspec.adapters.sqlite.core import (
    SqliteStreamSource,
    build_insert_statement,
    collect_rows,
    create_mapped_exception,
    default_statement_config,
    driver_profile,
    format_identifier,
    normalize_execute_many_parameters,
    normalize_execute_parameters,
    resolve_rowcount,
)
from sqlspec.adapters.sqlite.data_dictionary import SqliteDataDictionary
from sqlspec.core import ArrowResult, ParameterStyle, TypedParameter, get_cache_config, register_driver_profile
from sqlspec.core.result import DMLResult
from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase, SyncRowStream
from sqlspec.exceptions import SQLSpecError

if TYPE_CHECKING:
    from collections.abc import Sequence

    from sqlspec.adapters.sqlite._typing import SqliteConnection
    from sqlspec.builder import QueryBuilder
    from sqlspec.core import SQL, SQLResult, Statement, StatementConfig, StatementFilter
    from sqlspec.core.compiler import OperationType
    from sqlspec.driver import ExecutionResult
    from sqlspec.driver._query_cache import CachedQuery
    from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
    from sqlspec.typing import StatementParameters

# Public API of this module. SqliteCursor and SqliteSessionContext are
# re-exported from sqlspec.adapters.sqlite._typing alongside the driver classes.
__all__ = ("SqliteCursor", "SqliteDriver", "SqliteExceptionHandler", "SqliteSessionContext")


class SqliteExceptionHandler(BaseSyncExceptionHandler):
    """Exception-translating context manager for the SQLite adapter.

    Any ``sqlite3.Error`` raised inside the managed block is converted with
    ``create_mapped_exception`` and the result is stored on
    ``pending_exception`` rather than raised from ``__exit__``. This deferred
    pattern keeps exceptions from crossing the ABI boundary between
    interpreted and mypyc-compiled code. For any other exception type the
    handler reports that it did not handle the exception.
    """

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Record a mapped SQLSpec exception for sqlite3 errors.

        Args:
            exc_type: Type of the exception raised in the block, or None when
                the block completed without an exception.
            exc_val: The raised exception instance.

        Returns:
            True when ``exc_type`` is a ``sqlite3.Error`` subclass, in which
            case the mapped exception is stored on ``pending_exception``.
            False otherwise, including when no exception occurred.
        """
        is_sqlite_error = exc_type is not None and issubclass(exc_type, sqlite3.Error)
        if not is_sqlite_error:
            return False
        self.pending_exception = create_mapped_exception(exc_val)
        return True


class SqliteDriver(SyncDriverAdapterBase):
    """SQLite driver implementation.

    Provides SQL statement execution, transaction management, and result handling
    for SQLite databases using the standard sqlite3 module.
    """

    __slots__ = ("_data_dictionary",)
    dialect = "sqlite"

    def __init__(
        self,
        connection: "SqliteConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Set up the SQLite driver around an existing connection.

        When no ``statement_config`` is supplied, one is derived from
        ``default_statement_config`` via ``replace(...)``. Its
        ``enable_caching`` flag is taken from the global cache configuration's
        ``compiled_cache_enabled`` setting.

        Args:
            connection: Open SQLite database connection to drive.
            statement_config: Optional statement configuration; see above for
                the default.
            driver_features: Optional driver-specific feature flags, forwarded
                unchanged to the base adapter.
        """
        resolved_config = statement_config
        if resolved_config is None:
            caching_enabled = get_cache_config().compiled_cache_enabled
            resolved_config = default_statement_config.replace(enable_caching=caching_enabled)

        super().__init__(connection=connection, statement_config=resolved_config, driver_features=driver_features)
        # Data dictionary starts unset; it is None until assigned elsewhere.
        self._data_dictionary: SqliteDataDictionary | None = None
# ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement. Args: cursor: SQLite cursor object statement: SQL statement to execute Returns: ExecutionResult with statement execution details """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) if statement.returns_rows(): fetched_data = cursor.fetchall() data, column_names, row_count = collect_rows(fetched_data, cursor.description) return self.create_execution_result( cursor, selected_data=data, column_names=column_names, data_row_count=row_count, is_select_result=True, row_format="tuple", ) affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows)