Source code for sqlspec.adapters.mssql_python.driver

"""mssql-python sync and async drivers."""

import asyncio
from typing import TYPE_CHECKING, Any, TypedDict, cast

from typing_extensions import NotRequired

from sqlspec.adapters.mssql_python._typing import (
    MSSQL_PYTHON_MODULE,
    MssqlPythonAsyncCursor,
    MssqlPythonAsyncSessionContext,
    MssqlPythonConnection,
    MssqlPythonCursor,
    MssqlPythonRawCursor,
    MssqlPythonSessionContext,
)
from sqlspec.adapters.mssql_python.core import create_mapped_exception, default_statement_config, driver_profile
from sqlspec.adapters.mssql_python.data_dictionary import MssqlPythonAsyncDataDictionary, MssqlPythonSyncDataDictionary
from sqlspec.core import (
    build_arrow_result_from_reader,
    build_arrow_result_from_table,
    get_cache_config,
    register_driver_profile,
)
from sqlspec.driver import (
    AsyncDriverAdapterBase,
    BaseAsyncExceptionHandler,
    BaseSyncExceptionHandler,
    SyncDriverAdapterBase,
)
from sqlspec.exceptions import SQLSpecError
from sqlspec.utils.arrow_helpers import arrow_reader_with_deferred_close
from sqlspec.utils.logging import get_logger
from sqlspec.utils.module_loader import ensure_pyarrow

if TYPE_CHECKING:
    from collections.abc import Iterable

    from sqlspec.builder import QueryBuilder
    from sqlspec.core import SQL, ArrowResult, Statement, StatementConfig, StatementFilter
    from sqlspec.driver import ExecutionResult
    from sqlspec.typing import ArrowRecordBatchReader, ArrowReturnFormat, StatementParameters

# Public API of this module. The cursor and session-context names listed here are
# re-exported from ``sqlspec.adapters.mssql_python._typing``.
__all__ = (
    "MssqlPythonAsyncCursor",
    "MssqlPythonAsyncDriver",
    "MssqlPythonAsyncExceptionHandler",
    "MssqlPythonAsyncSessionContext",
    "MssqlPythonBulkCopyResult",
    "MssqlPythonCursor",
    "MssqlPythonDriver",
    "MssqlPythonExceptionHandler",
    "MssqlPythonSessionContext",
)

logger = get_logger("sqlspec.adapters.mssql_python")
# Exception class that the exception handlers in this module treat as a driver error:
# ``MSSQL_PYTHON_MODULE.Error`` when that attribute exists, otherwise plain ``Exception``.
# NOTE(review): with the ``Exception`` fallback, any ``Exception`` subclass (e.g. ``TypeError``)
# raised inside a handler context would be mapped as a driver error — confirm the fallback is
# only reachable when mssql-python itself is unavailable.
_MSSQL_ERROR = cast("type[BaseException]", getattr(MSSQL_PYTHON_MODULE, "Error", Exception))


class MssqlPythonBulkCopyResult(TypedDict):
    """BulkCopy statistics returned by mssql-python.

    ``rows_copied`` is a required key. ``batch_count`` and ``elapsed_time`` are
    declared ``NotRequired`` and may be absent from a given result.
    """

    # Required key.
    rows_copied: int
    # Optional key; presumably the number of batches sent — confirm against mssql-python docs.
    batch_count: NotRequired[int]
    # Optional key; NOTE(review): unit is not visible here (presumably seconds) — confirm.
    elapsed_time: NotRequired[float]


class MssqlPythonExceptionHandler(BaseSyncExceptionHandler):
    """Sync context manager that maps mssql-python driver errors to SQLSpec exceptions."""

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Record a mapped exception when a driver error was raised.

        Args:
            exc_type: Type of the exception raised in the managed block, or ``None`` if none was raised.
            exc_val: The raised exception instance.

        Returns:
            ``True`` when ``exc_val`` is an instance of ``_MSSQL_ERROR``; in that case the result of
            ``create_mapped_exception`` is stored on ``pending_exception``. ``False`` otherwise.
        """
        is_driver_error = exc_type is not None and isinstance(exc_val, _MSSQL_ERROR)
        if is_driver_error:
            mapped = create_mapped_exception(cast("Exception", exc_val), logger=logger)
            self.pending_exception = mapped
            return True
        return False


class MssqlPythonAsyncExceptionHandler(BaseAsyncExceptionHandler):
    """Async context manager that maps mssql-python driver errors to SQLSpec exceptions."""

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Record a mapped exception when a driver error was raised.

        Args:
            exc_type: Type of the exception raised in the managed block, or ``None`` if none was raised.
            exc_val: The raised exception instance.

        Returns:
            ``False`` when nothing was raised or ``exc_val`` is not an ``_MSSQL_ERROR`` instance.
            Otherwise stores the result of ``create_mapped_exception`` on ``pending_exception``
            and returns ``True``.
        """
        nothing_raised = exc_type is None
        # Short-circuit keeps the isinstance check from running when no exception occurred.
        if nothing_raised or not isinstance(exc_val, _MSSQL_ERROR):
            return False
        self.pending_exception = create_mapped_exception(cast("Exception", exc_val), logger=logger)
        return True


class MssqlPythonDriver(SyncDriverAdapterBase):
    """mssql-python sync driver."""

    __slots__ = ("_data_dictionary",)
    dialect = "tsql"

    def __init__(
        self,
        connection: "MssqlPythonConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Create a sync driver bound to ``connection``.

        Args:
            connection: Connection the driver executes statements against.
            statement_config: Statement configuration to use. When omitted,
                ``default_statement_config`` is used with ``enable_caching`` set from
                the global cache config's ``compiled_cache_enabled`` flag.
            driver_features: Optional feature mapping forwarded to the base driver.
        """
        config = statement_config
        if config is None:
            caching_enabled = get_cache_config().compiled_cache_enabled
            config = default_statement_config.replace(enable_caching=caching_enabled)
        super().__init__(connection=connection, statement_config=config, driver_features=driver_features)
        # Populated on first access of the ``data_dictionary`` property.
        self._data_dictionary: MssqlPythonSyncDataDictionary | None = None
@property def data_dictionary(self) -> "MssqlPythonSyncDataDictionary": if self._data_dictionary is None: self._data_dictionary = MssqlPythonSyncDataDictionary() return self._data_dictionary def dispatch_execute(self, cursor: "MssqlPythonRawCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) _execute_cursor(cursor, sql, prepared_parameters) if statement.returns_rows(): fetched = cursor.fetchall() column_names = [desc[0] for desc in (cursor.description or [])] return self.create_execution_result( cursor, selected_data=fetched, column_names=column_names, data_row_count=len(fetched), is_select_result=True, row_format="tuple", ) return self.create_execution_result(cursor, rowcount_override=_cursor_rowcount(cursor))