"""AsyncPG PostgreSQL driver implementation for async PostgreSQL operations."""
import re
from collections import OrderedDict
from io import BytesIO
from typing import TYPE_CHECKING, Any, Final, cast
import asyncpg
from sqlspec.adapters.asyncpg._typing import AsyncpgCursor, AsyncpgSessionContext
from sqlspec.adapters.asyncpg.core import (
PREPARED_STATEMENT_CACHE_SIZE,
AsyncpgStreamSource,
NormalizedStackOperation,
collect_rows,
create_mapped_exception,
default_statement_config,
driver_profile,
invoke_prepared_statement,
parse_status,
resolve_many_rowcount,
)
from sqlspec.adapters.asyncpg.data_dictionary import AsyncpgDataDictionary
from sqlspec.core import (
SQL,
StackResult,
StatementStack,
create_sql_result,
get_cache_config,
is_copy_from_operation,
is_copy_operation,
register_driver_profile,
)
from sqlspec.driver import (
AsyncDriverAdapterBase,
AsyncRowStream,
BaseAsyncExceptionHandler,
StackExecutionObserver,
describe_stack_statement,
)
from sqlspec.exceptions import SQLSpecError, StackExecutionError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.text import normalize_identifier, quote_identifier
from sqlspec.utils.type_guards import has_sqlstate
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlspec.adapters.asyncpg._typing import AsyncpgConnection, AsyncpgPreparedStatement
from sqlspec.core import ArrowResult, SQLResult, StatementConfig
from sqlspec.driver import ExecutionResult
from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
# Names re-exported as this module's public API.
__all__ = (
    "AsyncpgCursor",
    "AsyncpgDriver",
    "AsyncpgExceptionHandler",
    "AsyncpgSessionContext",
)
# Case-insensitive matcher for ``COPY <table> [(<columns>)] FROM STDIN``.
# Group 1 captures the target table: one bare (\w+) or double-quoted part,
# optionally followed by ``.<part>`` for a schema-qualified name.
# NOTE(review): quoted parts use ``"[^"]+"``, so identifiers containing an
# escaped double quote (``""``) will not match — confirm callers tolerate that.
_COPY_FROM_STDIN_RE: re.Pattern[str] = re.compile(
    r"COPY\s+"
    r'((?:"[^"]+"|\w+)(?:\.(?:"[^"]+"|\w+))?)'  # [schema.]table
    r"(?:\s*\([^)]*\))?"  # optional parenthesised column list
    r"\s+FROM\s+STDIN",
    flags=re.IGNORECASE,
)
# Part count of a schema-qualified ``schema.table`` name.
_QUALIFIED_TABLE_NAME_PARTS: Final = 2
# Module-level logger for the asyncpg adapter, obtained from sqlspec's logging
# helper under the "sqlspec.adapters.asyncpg" name.
logger = get_logger("sqlspec.adapters.asyncpg")
class AsyncpgExceptionHandler(BaseAsyncExceptionHandler):
    """Async context manager that translates asyncpg errors into SQLSpec exceptions.

    Any exception that is an ``asyncpg.PostgresError``, or that carries a
    SQLSTATE code (as reported by ``has_sqlstate``), is converted by
    ``create_mapped_exception`` into the matching SQLSpec exception.

    The converted exception is stored on ``pending_exception`` rather than
    raised from ``__aexit__``. This deferred pattern exists for mypyc
    compatibility: it avoids raising across the ABI boundary with compiled code.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Store a mapped SQLSpec exception when ``exc_val`` is a database error.

        Args:
            exc_type: Exception class passed alongside ``exc_val``; unused here.
            exc_val: The exception raised inside the managed block.

        Returns:
            True if ``exc_val`` was recognised as a database error and a mapped
            exception was stored on ``pending_exception``; False otherwise.
        """
        _ = exc_type  # unused; kept to satisfy the handler signature
        is_database_error = isinstance(exc_val, asyncpg.PostgresError) or has_sqlstate(exc_val)
        if not is_database_error:
            return False
        self.pending_exception = create_mapped_exception(exc_val)
        return True
class AsyncpgDriver(AsyncDriverAdapterBase):
"""AsyncPG PostgreSQL driver for async database operations.
Supports COPY operations, numeric parameter style handling, PostgreSQL
exception handling, transaction management, SQL statement compilation
and caching, and parameter processing with type coercion.
"""
__slots__ = ("_data_dictionary", "_prepared_statements")
dialect = "postgres"
def __init__(
    self,
    connection: "AsyncpgConnection",
    statement_config: "StatementConfig | None" = None,
    driver_features: "dict[str, Any] | None" = None,
) -> None:
    """Bind the driver to an asyncpg connection.

    Args:
        connection: The asyncpg connection to execute statements on.
        statement_config: Statement configuration to use. When omitted, the
            adapter's ``default_statement_config`` is used, with
            ``enable_caching`` taken from the global cache configuration's
            ``compiled_cache_enabled`` flag.
        driver_features: Optional driver feature settings, forwarded unchanged
            to the base adapter.
    """
    if statement_config is None:
        caching_enabled = get_cache_config().compiled_cache_enabled
        statement_config = default_statement_config.replace(enable_caching=caching_enabled)
    super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
    # Insertion-ordered map of prepared statements keyed by str. PREPARED_STATEMENT_CACHE_SIZE
    # is imported by this module, which suggests LRU-style bounding elsewhere — confirm.
    self._prepared_statements: OrderedDict[str, AsyncpgPreparedStatement] = OrderedDict()
    # Starts as None; presumably populated on demand elsewhere in the class — confirm.
    self._data_dictionary: AsyncpgDataDictionary | None = None