"""SQLite database configuration with thread-local connections."""
import contextlib
import logging
import sqlite3
import threading
import time
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, cast
from sqlspec.adapters.sqlite._typing import SqliteConnection
from sqlspec.utils.logging import POOL_LOGGER_NAME, get_logger, log_with_context
if TYPE_CHECKING:
from collections.abc import Callable, Generator
# Adapter identifier for this module (presumably used as log context by code
# outside this view — NOTE(review): confirm).
_ADAPTER_NAME = "sqlite"

# Module logger, registered under the shared pool logger name.
logger = get_logger(POOL_LOGGER_NAME)

# Names exported by ``from ... import *``.
__all__ = ("SqliteConnectionPool",)
def _dict_row_factory(cursor: Any, row: "tuple[Any, ...]") -> "dict[str, Any]":
    """Turn one result row into a ``{column_name: value}`` dict.

    Shaped for ``sqlite3.Connection.row_factory``. Column names come from the
    first element of each ``cursor.description`` entry. Values are taken from
    ``row`` by position. If a column name repeats, the last value wins.
    """
    mapped: "dict[str, Any]" = {}
    for position, column in enumerate(cursor.description):
        mapped[column[0]] = row[position]
    return mapped
def _resolve_row_factory(row_factory: Any) -> Any:
    """Map a row-factory shorthand to the value for ``Connection.row_factory``.

    - ``"row"`` maps to ``sqlite3.Row``.
    - ``"dict"`` maps to the module's column-name dict builder.
    - ``"tuple"`` maps to ``None``, which is sqlite3's default of plain tuple rows.
    - Any other value is returned unchanged, so a custom callable passes through.
    """
    if row_factory == "row":
        resolved = sqlite3.Row
    elif row_factory == "dict":
        resolved = _dict_row_factory
    elif row_factory == "tuple":
        resolved = None
    else:
        resolved = row_factory
    return resolved
def _load_extensions(connection: SqliteConnection, extensions: "list[str]") -> None:
    """Load every SQLite extension path in ``extensions`` into ``connection``.

    Extension loading is enabled only for the length of this call. It is
    disabled again in a ``finally`` clause, so it never stays enabled even if
    one of the loads raises. That exception still propagates.
    """
    toggle_loading = connection.enable_load_extension
    toggle_loading(True)
    try:
        for path in extensions:
            connection.load_extension(path)
    finally:
        toggle_loading(False)
def _apply_runtime_setup(connection: SqliteConnection, runtime_setup: "dict[str, Any]") -> None:
    """Apply user-supplied runtime configuration to ``connection``.

    Steps run in a fixed order: PRAGMAs, extensions, custom SQL functions,
    aggregates, collations, authorizer, trace callback, progress handler, then
    the row and text factories. Every key is optional. A missing key, or one
    whose value is ``None``, is skipped.

    Args:
        connection: The connection to configure.
        runtime_setup: Feature configuration. Recognised keys:

            - ``pragmas``: a mapping of pragma name to value, or an iterable of
              ``(name, value)`` pairs. Each entry runs as ``PRAGMA name = value``.
            - ``extensions``: iterable of extension paths to load.
            - ``custom_functions``: iterable of dicts with ``name``, ``narg``,
              ``func`` and an optional ``deterministic`` (default ``False``).
            - ``custom_aggregates``: iterable of dicts with ``name``, ``narg``
              and ``aggregate_class``.
            - ``custom_collations``: iterable of dicts with ``name`` and ``func``.
            - ``authorizer_callback`` and ``trace_callback``: installed when not None.
            - ``progress_handler``: installed when not None. It is called every
              ``progress_handler_interval`` SQLite VM instructions (default 1000).
            - ``row_factory``: ``"row"``, ``"dict"``, ``"tuple"``, or any other
              value, which is assigned unchanged.
            - ``text_factory``: assigned verbatim.

    Note:
        PRAGMA names and values are interpolated into the SQL text without
        escaping. They must come from trusted configuration, never from end-user
        input.
    """
    from collections.abc import Mapping

    # Iterating a dict directly yields only its keys, which cannot unpack into
    # (name, value). A mapping is therefore expanded to its items.
    pragmas = runtime_setup.get("pragmas") or ()
    pragma_pairs = pragmas.items() if isinstance(pragmas, Mapping) else pragmas
    for pragma_name, pragma_value in pragma_pairs:
        # SQLite's PRAGMA grammar does not accept bound parameters, so the
        # statement is built as text.
        connection.execute(f"PRAGMA {pragma_name} = {pragma_value}")

    extensions = runtime_setup.get("extensions")
    if extensions:
        _load_extensions(connection, list(extensions))

    # `or ()` makes an explicit None behave the same as a missing key.
    for function_config in runtime_setup.get("custom_functions") or ():
        connection.create_function(
            function_config["name"],
            function_config["narg"],
            function_config["func"],
            deterministic=function_config.get("deterministic", False),
        )
    for aggregate_config in runtime_setup.get("custom_aggregates") or ():
        connection.create_aggregate(
            aggregate_config["name"], aggregate_config["narg"], aggregate_config["aggregate_class"]
        )
    for collation_config in runtime_setup.get("custom_collations") or ():
        connection.create_collation(collation_config["name"], collation_config["func"])

    authorizer_callback = runtime_setup.get("authorizer_callback")
    if authorizer_callback is not None:
        connection.set_authorizer(authorizer_callback)
    trace_callback = runtime_setup.get("trace_callback")
    if trace_callback is not None:
        connection.set_trace_callback(trace_callback)
    progress_handler = runtime_setup.get("progress_handler")
    if progress_handler is not None:
        connection.set_progress_handler(progress_handler, runtime_setup.get("progress_handler_interval", 1000))

    # The factories are applied only when their key is present, so an explicit
    # value of None can still reset them.
    if "row_factory" in runtime_setup:
        connection.row_factory = _resolve_row_factory(runtime_setup["row_factory"])
    if "text_factory" in runtime_setup:
        connection.text_factory = runtime_setup["text_factory"]
class SqliteConnectionPool:
"""Thread-local connection manager for SQLite.
SQLite connections aren't thread-safe, so we use thread-local storage
to ensure each thread has its own connection. This is simpler and more
efficient than a traditional pool for SQLite's constraints.
"""
__slots__ = (
"_connection_parameters",
"_enable_optimizations",
"_health_check_interval",
"_on_connection_create",
"_pool_id",
"_recycle_seconds",
"_runtime_setup",
"_thread_local",
)
def __init__(
self,
connection_parameters: "dict[str, Any]",
enable_optimizations: bool = True,
recycle_seconds: int = 86400,
health_check_interval: float = 30.0,
on_connection_create: "Callable[[SqliteConnection], None] | None" = None,
runtime_setup: "dict[str, Any] | None" = None,
) -> None:
"""Initialize the thread-local connection manager.
Args:
connection_parameters: SQLite connection parameters
enable_optimizations: Whether to apply performance PRAGMAs
recycle_seconds: Connection recycle time in seconds (default 24h)
health_check_interval: Seconds of idle time before running health check
on_connection_create: Callback executed when connection is created
runtime_setup: Runtime feature configuration applied after internal PRAGMAs
"""
if "check_same_thread" not in connection_parameters:
connection_parameters = {**connection_parameters, "check_same_thread": False}
self._connection_parameters = connection_parameters
self._thread_local = threading.local()
self._enable_optimizations = enable_optimizations
self._recycle_seconds = recycle_seconds
self._health_check_interval = health_check_interval
self._on_connection_create = on_connection_create
self._runtime_setup = runtime_setup
self._pool_id = str(uuid.uuid4())[:8]