"""mssql-python sync and async drivers."""
import asyncio
from typing import TYPE_CHECKING, Any, TypedDict, cast
from typing_extensions import NotRequired
from sqlspec.adapters.mssql_python._typing import (
MSSQL_PYTHON_MODULE,
MssqlPythonAsyncCursor,
MssqlPythonAsyncSessionContext,
MssqlPythonConnection,
MssqlPythonCursor,
MssqlPythonRawCursor,
MssqlPythonSessionContext,
)
from sqlspec.adapters.mssql_python.core import create_mapped_exception, default_statement_config, driver_profile
from sqlspec.adapters.mssql_python.data_dictionary import MssqlPythonAsyncDataDictionary, MssqlPythonSyncDataDictionary
from sqlspec.core import (
build_arrow_result_from_reader,
build_arrow_result_from_table,
get_cache_config,
register_driver_profile,
)
from sqlspec.driver import (
AsyncDriverAdapterBase,
BaseAsyncExceptionHandler,
BaseSyncExceptionHandler,
SyncDriverAdapterBase,
)
from sqlspec.exceptions import SQLSpecError
from sqlspec.utils.arrow_helpers import arrow_reader_with_deferred_close
from sqlspec.utils.logging import get_logger
from sqlspec.utils.module_loader import ensure_pyarrow
if TYPE_CHECKING:
from collections.abc import Iterable
from sqlspec.builder import QueryBuilder
from sqlspec.core import SQL, ArrowResult, Statement, StatementConfig, StatementFilter
from sqlspec.driver import ExecutionResult
from sqlspec.typing import ArrowRecordBatchReader, ArrowReturnFormat, StatementParameters
# Public API of this module. The cursor and session-context names are
# re-exported from ``sqlspec.adapters.mssql_python._typing`` (imported above).
__all__ = (
    "MssqlPythonAsyncCursor",
    "MssqlPythonAsyncDriver",
    "MssqlPythonAsyncExceptionHandler",
    "MssqlPythonAsyncSessionContext",
    "MssqlPythonBulkCopyResult",
    "MssqlPythonCursor",
    "MssqlPythonDriver",
    "MssqlPythonExceptionHandler",
    "MssqlPythonSessionContext",
)
# Module logger; the exception handlers pass it to ``create_mapped_exception``.
logger = get_logger("sqlspec.adapters.mssql_python")
# Base class the exception handlers use (via ``isinstance``) to recognize
# driver errors. ``cast`` only informs the type checker; at runtime this is
# whatever ``getattr`` returns. If ``MSSQL_PYTHON_MODULE`` exposes no ``Error``
# attribute, the builtin ``Exception`` is used instead, which would make the
# handlers map *every* ``Exception`` instance.
# NOTE(review): confirm that broad fallback is intended.
_MSSQL_ERROR = cast("type[BaseException]", getattr(MSSQL_PYTHON_MODULE, "Error", Exception))
class MssqlPythonBulkCopyResult(TypedDict):
    """BulkCopy statistics returned by mssql-python.

    Only ``rows_copied`` is a required key; ``batch_count`` and
    ``elapsed_time`` are declared ``NotRequired`` and may be absent.

    NOTE(review): this key set is meant to mirror what mssql-python's bulk-copy
    API reports — confirm against the driver version in use.
    """

    # Row count reported for the copy (required key); presumably the number of
    # rows actually inserted — confirm against the driver.
    rows_copied: int
    # Number of batches reported for the copy (optional key).
    batch_count: NotRequired[int]
    # Time taken by the copy (optional key). Units are not established here —
    # presumably seconds; TODO confirm against mssql-python docs.
    elapsed_time: NotRequired[float]
class MssqlPythonExceptionHandler(BaseSyncExceptionHandler):
    """Synchronous handler that translates mssql-python driver errors.

    Exceptions that are instances of the driver's error base class
    (``_MSSQL_ERROR``) are converted with ``create_mapped_exception`` and the
    result is stored on ``pending_exception``. Any other exception is ignored
    by this handler.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Record a mapped exception when ``exc_val`` is a driver error.

        Args:
            exc_type: Type of the active exception, or ``None`` when none occurred.
            exc_val: The active exception instance.

        Returns:
            ``True`` if a mapped exception was stored on ``pending_exception``;
            ``False`` if there was no exception or it was not a driver error.
        """
        # Short-circuit keeps ``isinstance`` from running when no exception occurred.
        is_driver_error = exc_type is not None and isinstance(exc_val, _MSSQL_ERROR)
        if not is_driver_error:
            return False
        self.pending_exception = create_mapped_exception(cast("Exception", exc_val), logger=logger)
        return True
class MssqlPythonAsyncExceptionHandler(BaseAsyncExceptionHandler):
    """Asynchronous counterpart of the mssql-python exception handler.

    Driver errors (instances of ``_MSSQL_ERROR``) are turned into mapped
    exceptions via ``create_mapped_exception`` and kept on
    ``pending_exception``; other exceptions are not handled here.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Store a mapped exception for driver errors.

        Args:
            exc_type: Type of the active exception, or ``None`` if nothing was raised.
            exc_val: The active exception instance.

        Returns:
            ``True`` when ``pending_exception`` was set; ``False`` otherwise.
        """
        if exc_type is not None and isinstance(exc_val, _MSSQL_ERROR):
            mapped = create_mapped_exception(cast("Exception", exc_val), logger=logger)
            self.pending_exception = mapped
            return True
        return False
class MssqlPythonDriver(SyncDriverAdapterBase):
"""mssql-python sync driver."""
__slots__ = ("_data_dictionary",)
dialect = "tsql"
def __init__(
    self,
    connection: "MssqlPythonConnection",
    statement_config: "StatementConfig | None" = None,
    driver_features: "dict[str, Any] | None" = None,
) -> None:
    """Create a sync driver around an mssql-python connection.

    Args:
        connection: mssql-python connection the driver wraps.
        statement_config: Statement configuration to use. When omitted, one is
            derived via ``default_statement_config.replace(...)`` with
            ``enable_caching`` taken from
            ``get_cache_config().compiled_cache_enabled``.
        driver_features: Optional driver feature mapping, forwarded unchanged
            to the base adapter.
    """
    if statement_config is not None:
        effective_config = statement_config
    else:
        compiled_cache_on = get_cache_config().compiled_cache_enabled
        effective_config = default_statement_config.replace(enable_caching=compiled_cache_on)

    super().__init__(
        connection=connection,
        statement_config=effective_config,
        driver_features=driver_features,
    )
    # Starts unset. NOTE(review): presumably populated lazily by a
    # data-dictionary accessor elsewhere in this class — confirm.
    self._data_dictionary: MssqlPythonSyncDataDictionary | None = None