# Source code for sqlspec.driver._async

"""Asynchronous driver protocol implementation."""

import logging
from abc import abstractmethod
from time import perf_counter
from typing import TYPE_CHECKING, Any, ClassVar, Final, cast, final, overload

from mypy_extensions import mypyc_attr

from sqlspec.core import SQL, StackResult, create_arrow_result
from sqlspec.core.result import DMLResult
from sqlspec.core.stack import StackOperation, StatementStack
from sqlspec.driver._common import (
    AsyncExceptionHandler,
    CommonDriverAttributesMixin,
    DataDictionaryDialectMixin,
    DataDictionaryMixin,
    ExecutionResult,
    StackExecutionObserver,
    _raise_database_exception,
    describe_stack_statement,
    handle_single_row_error,
)
from sqlspec.driver._query_cache import CachedQuery
from sqlspec.driver._sql_helpers import DEFAULT_PRETTY
from sqlspec.driver._sql_helpers import convert_to_dialect as _convert_to_dialect_impl
from sqlspec.driver._storage_helpers import stringify_storage_target
from sqlspec.driver._stream import AsyncRowStream, _LazyEagerAsyncRowSource
from sqlspec.exceptions import ImproperConfigurationError, StackExecutionError
from sqlspec.observability import _runtime as observability_runtime
from sqlspec.storage import AsyncStoragePipeline, StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
from sqlspec.utils.arrow_helpers import convert_dict_to_arrow_with_schema
from sqlspec.utils.logging import get_logger, log_with_context
from sqlspec.utils.schema import ValueT, to_value_type

if TYPE_CHECKING:
    from collections.abc import Awaitable, Sequence

    from sqlglot.dialects.dialect import DialectType

    from sqlspec.builder import QueryBuilder
    from sqlspec.core import ArrowResult, SQLResult, Statement, StatementConfig, StatementFilter
    from sqlspec.data_dictionary import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo
    from sqlspec.typing import ArrowReturnFormat, ArrowTable, SchemaT, StatementParameters


# Public names exported by this module.
__all__ = (
    "AsyncDataDictionaryBase",
    "AsyncDriverAdapterBase",
    "AsyncPoolConnectionContext",
    "AsyncPoolSessionFactory",
)


# Records from this module are logged under the package-level "sqlspec.driver" name.
_LOGGER_NAME: Final[str] = "sqlspec.driver"
logger = get_logger(_LOGGER_NAME)


@mypyc_attr(allow_interpreted_subclasses=True)
class AsyncPoolConnectionContext:
    """Base async connection context using pool acquire/release pattern.

    Subclass per adapter and override ``__aenter__``/``__aexit__`` for
    adapter-specific pool acquisition and release logic.

    The default implementation treats ``config`` as a duck-typed object: it
    reads and writes ``config.connection_instance`` (the pool), awaits
    ``config.create_pool()`` when no pool exists yet, and awaits
    ``pool.acquire()`` / ``pool.release(connection)``.
    """

    __slots__ = ("_config", "_connection")

    def __init__(self, config: Any) -> None:
        self._config = config
        # Connection currently checked out by this context; None when idle.
        self._connection: Any = None

    async def __aenter__(self) -> Any:
        """Acquire a connection, lazily creating and caching the pool first.

        Returns:
            The connection returned by ``pool.acquire()``.
        """
        pool = self._config.connection_instance
        if pool is None:
            pool = await self._config.create_pool()
            # Store the new pool on the config so subsequent acquisitions reuse it.
            self._config.connection_instance = pool
        self._connection = await pool.acquire()
        return self._connection

    async def __aexit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any
    ) -> "bool | None":
        """Return the checked-out connection to the pool.

        The tracked connection is cleared before ``release()`` is awaited. A
        failing release therefore cannot leave a stale handle that a later exit
        would try to release again. If the config no longer holds a pool, the
        connection is dropped without a release call.

        Returns:
            None, so exceptions raised inside the ``async with`` body propagate.
        """
        connection = self._connection
        if connection is None:
            return None
        self._connection = None
        pool = self._config.connection_instance
        # Explicit None check: a pool object could define __bool__/__len__ and
        # evaluate falsy, which must not cause the release to be skipped.
        if pool is not None:
            await pool.release(connection)
        return None


@mypyc_attr(allow_interpreted_subclasses=True)
class AsyncPoolSessionFactory:
    """Base async session factory using pool acquire/release pattern.

    Subclass per adapter and override ``acquire_connection``/``release_connection``
    for adapter-specific pool acquisition and release logic.

    The default implementation treats ``config`` as a duck-typed object: it
    reads and writes ``config.connection_instance`` (the pool), awaits
    ``config.create_pool()`` when no pool exists yet, and awaits
    ``pool.acquire()`` / ``pool.release(connection)``.
    """

    __slots__ = ("_config", "_connection")

    def __init__(self, config: Any) -> None:
        self._config = config
        # Connection currently checked out through this factory; None when idle.
        self._connection: Any = None

    async def acquire_connection(self) -> Any:
        """Acquire a connection, lazily creating and caching the pool first.

        Returns:
            The connection returned by ``pool.acquire()``.
        """
        pool = self._config.connection_instance
        if pool is None:
            pool = await self._config.create_pool()
            # Store the new pool on the config so subsequent acquisitions reuse it.
            self._config.connection_instance = pool
        self._connection = await pool.acquire()
        return self._connection

    async def release_connection(self, _conn: Any, **kwargs: Any) -> None:
        """Release the connection previously obtained via ``acquire_connection``.

        The tracked connection is always forgotten. This holds even when the
        config no longer holds a pool, and even when ``release()`` raises. A
        stale handle is therefore never released twice, and it is never handed
        to a pool that replaced the one it came from.

        Args:
            _conn: Ignored by this base implementation; the factory releases the
                connection it tracks itself.
            **kwargs: Ignored by this base implementation.
        """
        connection = self._connection
        if connection is None:
            return
        self._connection = None
        pool = self._config.connection_instance
        if pool is not None:
            await pool.release(connection)


@mypyc_attr(allow_interpreted_subclasses=True)
class AsyncDriverAdapterBase(CommonDriverAttributesMixin):
    """Base class for asynchronous database drivers.

    This class includes flattened storage and SQL translation methods that were
    previously in StorageDriverMixin and SQLTranslatorMixin. The flattening
    eliminates cross-trait attribute access that caused mypyc segmentation faults.

    Method Organization:
        1. Core dispatch methods (the execution engine)
        2. Transaction management (abstract methods)
        3. Public API - execution methods
        4. Public API - query methods (select/fetch variants)
        5. Arrow API methods
        6. Stack execution
        7. Storage API methods
        8. Utility methods
        9. Private/internal methods
    """

    __slots__ = ()

    dialect: "DialectType | None" = None

    @property
    def is_async(self) -> bool:
        """Report that drivers derived from this base execute asynchronously.

        Returns:
            ``True`` unconditionally for this async driver base.
        """
        runs_asynchronously = True
        return runs_asynchronously

    @property
    @abstractmethod
    def data_dictionary(self) -> "AsyncDataDictionaryBase":
        """Return the metadata-query helper associated with this driver.

        Declared abstract here; concrete adapters are expected to provide it.

        Returns:
            The data dictionary instance used for metadata queries.
        """
        ...

    async def set_migration_session_schema(self, schema: str) -> None:
        """Apply a default schema for migration SQL, if the driver supports it.

        This base implementation performs no database work. It only emits a
        DEBUG-level ``migration.schema.noop`` event tagged with the requested
        schema and the concrete driver class name.

        Args:
            schema: Schema requested for the current migration session.
        """
        driver_name = type(self).__name__
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.schema.noop",
            schema=schema,
            driver=driver_name,
        )
async def set_migration_non_transactional_schema(self, schema: str) -> None: """Set the default schema for non-transactional migration SQL when supported. Args: schema: Schema requested for the current migration session. """ await self.set_migration_session_schema(schema)