"""Utility functions for SQLSpec migrations."""
import importlib
import inspect
import os
import subprocess
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
from sqlspec.migrations.templates import MigrationTemplateSettings, TemplateValidationError, build_template_settings
from sqlspec.utils.logging import get_logger
from sqlspec.utils.text import slugify
if TYPE_CHECKING:
from collections.abc import Callable, Mapping
from sqlspec.config import DatabaseConfigProtocol
from sqlspec.driver import AsyncDriverAdapterBase
# Public API of this module. Some listed names (e.g. ``drop_all``) are not in
# this section and are presumably defined further down the file — verify.
__all__ = ("create_migration_file", "drop_all", "get_author", "resolve_default_schema", "resolve_tracker_schema")
# Module-level logger obtained through sqlspec's logging helper.
logger = get_logger(__name__)
def resolve_default_schema(migration_config: "Mapping[str, Any] | None") -> str | None:
"""Resolve the configured default migration schema.
Args:
migration_config: Migration configuration mapping.
Returns:
Default schema string when configured, otherwise ``None``.
"""
if not migration_config:
return None
default_schema = migration_config.get("default_schema")
if isinstance(default_schema, str) and default_schema:
return default_schema
return None
def resolve_tracker_schema(migration_config: "Mapping[str, Any] | None") -> str | None:
"""Resolve the schema for the migration tracking table.
Args:
migration_config: Migration configuration mapping.
Returns:
Explicit tracker schema, default migration schema, or None.
"""
if not migration_config:
return None
version_table_schema = migration_config.get("version_table_schema")
if isinstance(version_table_schema, str) and version_table_schema:
return version_table_schema
return resolve_default_schema(migration_config)
def create_migration_file(
    migrations_dir: Path,
    version: str,
    message: str,
    file_type: str | None = None,
    *,
    config: "DatabaseConfigProtocol[Any, Any, Any] | None" = None,
    template_settings: "MigrationTemplateSettings | None" = None,
) -> Path:
    """Create a new migration file from a template.

    The file is named ``{version}_{slug}.{ext}``.

    - ``slug`` is the slugified ``message``, or ``"migration"`` when the slug
      is empty.
    - ``ext`` is ``py`` when the resolved format is ``"py"`` and ``sql`` otherwise.

    Args:
        migrations_dir: Directory the migration file is written into.
        version: Migration version, used as the filename prefix and passed to
            the template context.
        message: Human-readable description. It is slugified for the filename
            and also passed unchanged to the template context.
        file_type: Requested file format, resolved through
            ``settings.resolve_format``. Presumably ``None`` selects the
            configured default; confirm against ``MigrationTemplateSettings``.
        config: Optional database configuration. It supplies
            ``migration_config`` (author, template settings), the adapter name
            and the project slug used in the template context.
        template_settings: Pre-built template settings. When omitted they are
            built from the configuration's migration config.

    Returns:
        Path of the written migration file.

    Raises:
        OSError: If the file cannot be written, for example when
            ``migrations_dir`` does not exist.
    """
    # Guard against a config whose ``migration_config`` is None or otherwise
    # falsy. Normalising it to an empty mapping keeps the ``.get`` lookup below
    # from failing with AttributeError. Real dicts pass through unchanged.
    raw_config = config.migration_config if config is not None else None
    migration_config = cast("dict[str, Any]", raw_config or {})
    settings = template_settings or build_template_settings(migration_config)
    author = get_author(migration_config.get("author"), config=config)
    safe_message = _slugify_message(message)
    file_format = settings.resolve_format(file_type)
    extension = "py" if file_format == "py" else "sql"
    # Fall back to a generic stem when the message slugifies to an empty string.
    filename = f"{version}_{safe_message or 'migration'}.{extension}"
    file_path = migrations_dir / filename
    context = _build_template_context(
        settings=settings,
        version=version,
        message=message,
        author=author,
        adapter=_resolve_adapter_name(config),
        project_slug=_derive_project_slug(config),
        safe_message=safe_message,
    )
    renderer = settings.profile.python.render if file_format == "py" else settings.profile.sql.render
    content = renderer(context)
    # NOTE(review): write_text truncates an existing file, so a filename
    # collision silently overwrites an existing migration.
    file_path.write_text(content, encoding="utf-8")
    return file_path