# Source code for sqlspec.adapters.asyncpg.driver

"""AsyncPG PostgreSQL driver implementation for async PostgreSQL operations."""

import re
from collections import OrderedDict
from io import BytesIO
from typing import TYPE_CHECKING, Any, Final, cast

import asyncpg

from sqlspec.adapters.asyncpg._typing import AsyncpgCursor, AsyncpgSessionContext
from sqlspec.adapters.asyncpg.core import (
    PREPARED_STATEMENT_CACHE_SIZE,
    AsyncpgStreamSource,
    NormalizedStackOperation,
    collect_rows,
    create_mapped_exception,
    default_statement_config,
    driver_profile,
    invoke_prepared_statement,
    parse_status,
    resolve_many_rowcount,
)
from sqlspec.adapters.asyncpg.data_dictionary import AsyncpgDataDictionary
from sqlspec.core import (
    SQL,
    StackResult,
    StatementStack,
    create_sql_result,
    get_cache_config,
    is_copy_from_operation,
    is_copy_operation,
    register_driver_profile,
)
from sqlspec.driver import (
    AsyncDriverAdapterBase,
    AsyncRowStream,
    BaseAsyncExceptionHandler,
    StackExecutionObserver,
    describe_stack_statement,
)
from sqlspec.exceptions import SQLSpecError, StackExecutionError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.text import normalize_identifier, quote_identifier
from sqlspec.utils.type_guards import has_sqlstate

if TYPE_CHECKING:
    from collections.abc import Sequence

    from sqlspec.adapters.asyncpg._typing import AsyncpgConnection, AsyncpgPreparedStatement
    from sqlspec.core import ArrowResult, SQLResult, StatementConfig
    from sqlspec.driver import ExecutionResult
    from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry


# Public names exported by this module. AsyncpgCursor and AsyncpgSessionContext are
# imported from sqlspec.adapters.asyncpg._typing above and re-exported here.
__all__ = ("AsyncpgCursor", "AsyncpgDriver", "AsyncpgExceptionHandler", "AsyncpgSessionContext")

# Recognises ``COPY <table> [(columns)] FROM STDIN`` case-insensitively.
# Group 1 captures the table reference: one identifier, or two joined by a dot,
# where each identifier is either a bare word or a double-quoted name.
_COPY_FROM_STDIN_RE: re.Pattern[str] = re.compile(
    r'COPY\s+'
    r'((?:"[^"]+"|\w+)(?:\.(?:"[^"]+"|\w+))?)'  # table, optionally schema-qualified
    r'(?:\s*\([^)]*\))?'  # optional parenthesised column list (not captured)
    r'\s+FROM\s+STDIN',
    flags=re.IGNORECASE,
)
# NOTE(review): presumably the part count of a ``schema.table`` reference; the
# code that uses it is outside this view -- confirm.
_QUALIFIED_TABLE_NAME_PARTS: Final = 2

# Module-level logger obtained from get_logger under the "sqlspec.adapters.asyncpg" name.
logger = get_logger("sqlspec.adapters.asyncpg")


class AsyncpgExceptionHandler(BaseAsyncExceptionHandler):
    """Async context manager that translates AsyncPG errors into SQLSpec exceptions.

    PostgreSQL SQLSTATE error codes are mapped onto specific SQLSpec
    exception types so application code can handle them precisely.

    The handler follows the deferred exception pattern required for mypyc
    compatibility: the mapped exception is stored in ``pending_exception``
    instead of being raised from ``__aexit__``, which avoids ABI boundary
    violations with compiled code.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Record a mapped exception for errors this handler recognises.

        Args:
            exc_type: Exception class being handled; not consulted here.
            exc_val: Exception instance being handled.

        Returns:
            True when ``exc_val`` is an ``asyncpg.PostgresError`` or satisfies
            ``has_sqlstate`` and its mapped exception has been stored in
            ``pending_exception``; False otherwise.
        """
        _ = exc_type  # only the exception instance drives the decision
        is_postgres_error = isinstance(exc_val, asyncpg.PostgresError)
        # has_sqlstate is consulted only when the isinstance check fails.
        if not (is_postgres_error or has_sqlstate(exc_val)):
            return False

        self.pending_exception = create_mapped_exception(exc_val)
        return True


class AsyncpgDriver(AsyncDriverAdapterBase):
    """AsyncPG PostgreSQL driver for async database operations.

    Supports COPY operations, numeric parameter style handling, PostgreSQL
    exception handling, transaction management, SQL statement compilation
    and caching, and parameter processing with type coercion.
    """

    __slots__ = ("_data_dictionary", "_prepared_statements")
    dialect = "postgres"

    def __init__(
        self,
        connection: "AsyncpgConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialise the driver around an AsyncPG connection.

        Args:
            connection: AsyncPG connection forwarded to the base driver.
            statement_config: Statement configuration to use. When omitted, one is
                built from ``default_statement_config.replace(...)`` with
                ``enable_caching`` set to ``get_cache_config().compiled_cache_enabled``.
            driver_features: Optional feature mapping forwarded to the base driver.
        """
        resolved_config = statement_config
        if resolved_config is None:
            caching_enabled = get_cache_config().compiled_cache_enabled
            resolved_config = default_statement_config.replace(enable_caching=caching_enabled)

        super().__init__(connection=connection, statement_config=resolved_config, driver_features=driver_features)

        # Starts empty and is keyed by str.
        # NOTE(review): presumably bounded by PREPARED_STATEMENT_CACHE_SIZE elsewhere
        # in the class -- confirm.
        self._prepared_statements: OrderedDict[str, AsyncpgPreparedStatement] = OrderedDict()
        # Starts unset; whatever assigns it lies outside this method.
        self._data_dictionary: AsyncpgDataDictionary | None = None
# ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── async def dispatch_execute(self, cursor: "AsyncpgConnection", statement: "SQL") -> "ExecutionResult": """Execute single SQL statement. Handles both SELECT queries and non-SELECT operations. Args: cursor: AsyncPG connection object statement: SQL statement to execute Returns: ExecutionResult with statement execution details """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) params: tuple[Any, ...] = cast("tuple[Any, ...]", prepared_parameters) if prepared_parameters else () if statement.returns_rows(): records = await cursor.fetch(sql, *params) if params else await cursor.fetch(sql) data, column_names = collect_rows(records) return self.create_execution_result( cursor, selected_data=data, column_names=column_names, data_row_count=len(data), is_select_result=True, row_format="record", ) result = await cursor.execute(sql, *params) if params else await cursor.execute(sql) affected_rows = parse_status(result) return self.create_execution_result(cursor, rowcount_override=affected_rows)