# Source code for sqlspec.extensions.fastapi.extension

from typing import TYPE_CHECKING, Any, overload

from fastapi import FastAPI, Request

from sqlspec.base import SQLSpec
from sqlspec.extensions.fastapi.providers import DEPENDENCY_DEFAULTS
from sqlspec.extensions.fastapi.providers import provide_filters as _provide_filters
from sqlspec.extensions.starlette.extension import SQLSpecPlugin as _StarlettePlugin

if TYPE_CHECKING:
    # Everything in this branch is visible to static analysis only; none of these
    # imports is executed at runtime. Annotations that mention these names must
    # therefore be written as strings or go through the aliases defined below.
    from collections.abc import Callable

    from sqlspec.config import AsyncDatabaseConfig, SyncDatabaseConfig
    from sqlspec.core import FilterTypes
    from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
    from sqlspec.extensions.fastapi.providers import DependencyDefaults, FilterConfig

    # Type aliases for static analysis - IDEs and type checkers see the real driver types.
    _AsyncSession = AsyncDriverAdapterBase
    _SyncSession = SyncDriverAdapterBase
    _Session = AsyncDriverAdapterBase | SyncDriverAdapterBase
else:
    # Runtime fallback - the driver classes above are not imported in this branch,
    # so the aliases resolve to Any. FastAPI sees Any, which avoids a NameError
    # when an unquoted annotation using one of these aliases is evaluated.
    _AsyncSession = Any
    _SyncSession = Any
    _Session = Any

# Only the plugin class is part of this module's declared public API.
__all__ = ("SQLSpecPlugin",)


class SQLSpecPlugin(_StarlettePlugin):
    """SQLSpec integration for FastAPI applications.

    Extends Starlette integration with dependency injection helpers for FastAPI's
    Depends() system.
    """

    def __init__(self, sqlspec: SQLSpec, app: "FastAPI | None" = None) -> None:
        """Initialize SQLSpec FastAPI extension.

        Args:
            sqlspec: Pre-configured SQLSpec instance with registered configs.
            app: Optional FastAPI application to initialize immediately.
        """
        super().__init__(sqlspec, app)

    def _extract_extension_settings(self, config: Any) -> "dict[str, Any]":
        """Extract FastAPI settings from config.extension_config.

        Reads the ``"fastapi"`` section of ``config.extension_config`` and fills in
        a default for every setting that is absent.

        Args:
            config: Database configuration instance. It must expose
                ``extension_config`` (a mapping) and ``supports_connection_pooling``.

        Returns:
            Dictionary of FastAPI-specific settings. ``correlation_headers`` is
            returned as a tuple when configured and as ``None`` otherwise.
        """
        fastapi_config = config.extension_config.get("fastapi", {})

        connection_key = fastapi_config.get("connection_key", "db_connection")
        pool_key = fastapi_config.get("pool_key", "db_pool")
        session_key = fastapi_config.get("session_key", "db_session")
        commit_mode = fastapi_config.get("commit_mode", "manual")

        # A config without connection pooling that is still on the default pool key
        # gets a key derived from the config object's identity instead.
        # NOTE(review): presumably this keeps several non-pooled configs from
        # sharing one key - confirm against the Starlette base plugin.
        if not config.supports_connection_pooling and pool_key == "db_pool":
            pool_key = f"_db_pool_{id(config)}"

        # Normalize any configured iterable of header names to a tuple; None is
        # passed through unchanged to mean "not configured".
        correlation_headers = fastapi_config.get("correlation_headers")
        if correlation_headers is not None:
            correlation_headers = tuple(correlation_headers)

        return {
            "connection_key": connection_key,
            "pool_key": pool_key,
            "session_key": session_key,
            "commit_mode": commit_mode,
            "extra_commit_statuses": fastapi_config.get("extra_commit_statuses"),
            "extra_rollback_statuses": fastapi_config.get("extra_rollback_statuses"),
            "disable_di": fastapi_config.get("disable_di", False),
            "enable_correlation_middleware": fastapi_config.get("enable_correlation_middleware", False),
            "correlation_header": fastapi_config.get("correlation_header", "x-request-id"),
            "correlation_headers": correlation_headers,
            "auto_trace_headers": fastapi_config.get("auto_trace_headers", True),
            "enable_sqlcommenter_middleware": fastapi_config.get("enable_sqlcommenter_middleware", True),
            "sqlcommenter_framework": fastapi_config.get("sqlcommenter_framework", "fastapi"),
        }

    # The overloads below exist purely for static typing: passing a config class or
    # instance narrows the declared session type to the async or sync driver base.

    @overload
    def provide_session(
        self, key: None = None
    ) -> "Callable[[Request], AsyncDriverAdapterBase | SyncDriverAdapterBase]": ...

    @overload
    def provide_session(self, key: str) -> "Callable[[Request], AsyncDriverAdapterBase | SyncDriverAdapterBase]": ...

    @overload
    def provide_session(self, key: "type[AsyncDatabaseConfig]") -> "Callable[[Request], AsyncDriverAdapterBase]": ...

    @overload
    def provide_session(self, key: "type[SyncDatabaseConfig]") -> "Callable[[Request], SyncDriverAdapterBase]": ...

    @overload
    def provide_session(self, key: "AsyncDatabaseConfig") -> "Callable[[Request], AsyncDriverAdapterBase]": ...

    @overload
    def provide_session(self, key: "SyncDatabaseConfig") -> "Callable[[Request], SyncDriverAdapterBase]": ...

    def provide_session(
        self,
        key: "str | type[AsyncDatabaseConfig | SyncDatabaseConfig] | AsyncDatabaseConfig | SyncDatabaseConfig | None" = None,
    ) -> "Callable[[Request], AsyncDriverAdapterBase | SyncDriverAdapterBase]":
        """Create dependency factory for session injection.

        Returns a callable that can be used with FastAPI's Depends() to inject
        a database session into route handlers.

        Args:
            key: Optional session key (str), config type for type narrowing, or None.
                Only a string key is forwarded to ``get_session``; a config type or
                instance affects the static return type only and is treated like
                ``None`` at runtime.

        Returns:
            Dependency callable for FastAPI Depends(). When invoked with the current
            request it returns ``self.get_session(request, session_key)``.
        """
        # Extract string key if provided, ignore config types/instances (used only for type narrowing).
        session_key = key if isinstance(key, str) else None

        def dependency(request: Request) -> _Session:
            return self.get_session(request, session_key)  # type: ignore[no-any-return]

        return dependency