# Source code for sqlspec.base

import asyncio
import atexit
import weakref
from collections.abc import Awaitable, Coroutine
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeGuard, cast, overload

from mypy_extensions import mypyc_attr
from typing_extensions import Self, TypeVar

from sqlspec.config import (
    AsyncConfigT,
    AsyncDatabaseConfig,
    DatabaseConfigProtocol,
    DriverT,
    NoPoolAsyncConfig,
    NoPoolSyncConfig,
    SyncConfigT,
    SyncDatabaseConfig,
)
from sqlspec.core import (
    CacheConfig,
    get_cache_config,
    get_cache_statistics,
    log_cache_stats,
    reset_stats_only,
    update_cache_config,
)
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events import AsyncEventChannel, SyncEventChannel
from sqlspec.loader import SQLFileLoader
from sqlspec.observability import ObservabilityConfig, ObservabilityRuntime, TelemetryDiagnostics
from sqlspec.typing import ConnectionT
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import has_name

if TYPE_CHECKING:
    from collections.abc import Sequence
    from pathlib import Path
    from types import TracebackType

    from sqlspec.core import SQL, ParameterDeclaration
    from sqlspec.typing import PoolT


# Public API of this module: only the ``SQLSpec`` registry is exported via ``from ... import *``.
__all__ = ("SQLSpec",)

# Module-level logger for registry diagnostics (e.g. duplicate-config debug messages).
# get_logger() is called without a name; the resulting logger name is decided by
# sqlspec.utils.logging (not visible here).
logger = get_logger()


class SQLSpec:
    """Configuration manager and registry for database connections and pools."""

    __slots__ = ("__weakref__", "_configs", "_loader", "_loader_runtime", "_observability_config")

    _live_instances: "ClassVar[weakref.WeakSet[SQLSpec]]" = weakref.WeakSet()
    _atexit_registered: ClassVar[bool] = False

    def __init__(
        self,
        *,
        loader: "SQLFileLoader | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
    ) -> None:
        """Create an empty configuration registry.

        Args:
            loader: Optional SQL file loader. When given, it is wired to the
                observability runtime this registry builds for the loader.
            observability_config: Optional observability settings. They are kept
                on the registry and used to build the loader's observability
                runtime (named ``"SQLFileLoader"``).
        """
        self._configs: dict[int, DatabaseConfigProtocol[Any, Any, Any]] = {}

        # Every live registry is tracked in a class-wide WeakSet (entries drop out
        # once an instance is garbage collected). The process-exit hook is
        # installed only once per process, guarded by a class-level flag.
        registry_cls = SQLSpec
        registry_cls._live_instances.add(self)
        if not registry_cls._atexit_registered:
            atexit.register(registry_cls._cleanup_all_sync_pools)
            registry_cls._atexit_registered = True

        self._observability_config = observability_config
        loader_runtime = ObservabilityRuntime(observability_config, config_name="SQLFileLoader")
        self._loader_runtime = loader_runtime
        self._loader = loader
        if loader is not None:
            loader.set_observability_runtime(loader_runtime)
@overload def add_config(self, config: "SyncConfigT") -> "SyncConfigT": ... @overload def add_config(self, config: "AsyncConfigT") -> "AsyncConfigT": ... def add_config(self, config: "SyncConfigT | AsyncConfigT") -> "SyncConfigT | AsyncConfigT": """Add a configuration instance to the registry. Args: config: The configuration instance to add. Returns: The same configuration instance (it IS the handle). """ config_id = id(config) if config_id in self._configs: logger.debug("Configuration for %s already exists. Overwriting.", config.__class__.__name__) config.attach_observability(self._observability_config) self._configs[config_id] = config return config