Source code for sqlspec.migrations.utils

"""Utility functions for SQLSpec migrations."""

import importlib
import inspect
import os
import subprocess
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

from sqlspec.migrations.templates import MigrationTemplateSettings, TemplateValidationError, build_template_settings
from sqlspec.utils.logging import get_logger
from sqlspec.utils.text import slugify

if TYPE_CHECKING:
    from collections.abc import Callable, Mapping

    from sqlspec.config import DatabaseConfigProtocol
    from sqlspec.driver import AsyncDriverAdapterBase

__all__ = ("create_migration_file", "drop_all", "get_author", "resolve_default_schema", "resolve_tracker_schema")

logger = get_logger(__name__)


def resolve_default_schema(migration_config: "Mapping[str, Any] | None") -> str | None:
    """Resolve the configured default migration schema.

    Args:
        migration_config: Migration configuration mapping.

    Returns:
        Default schema string when configured, otherwise ``None``.
    """
    if not migration_config:
        return None
    default_schema = migration_config.get("default_schema")
    if isinstance(default_schema, str) and default_schema:
        return default_schema
    return None


def resolve_tracker_schema(migration_config: "Mapping[str, Any] | None") -> str | None:
    """Resolve the schema for the migration tracking table.

    Args:
        migration_config: Migration configuration mapping.

    Returns:
        Explicit tracker schema, default migration schema, or None.
    """
    if not migration_config:
        return None
    version_table_schema = migration_config.get("version_table_schema")
    if isinstance(version_table_schema, str) and version_table_schema:
        return version_table_schema
    return resolve_default_schema(migration_config)


def create_migration_file(
    migrations_dir: Path,
    version: str,
    message: str,
    file_type: str | None = None,
    *,
    config: "DatabaseConfigProtocol[Any, Any, Any] | None" = None,
    template_settings: "MigrationTemplateSettings | None" = None,
) -> Path:
    """Render a migration template and write it into ``migrations_dir``.

    Args:
        migrations_dir: Directory the new migration file is written into.
        version: Version prefix used at the start of the filename.
        message: Human-readable description; a slugified form goes into the filename.
        file_type: Requested format, passed to ``settings.resolve_format``.
        config: Optional database configuration whose ``migration_config`` drives
            template settings, author resolution, adapter name, and project slug.
        template_settings: Pre-built template settings; when falsy they are built
            from the migration config.

    Returns:
        Path of the file that was written (UTF-8 encoded).
    """
    if config is None:
        migration_config: dict[str, Any] = {}
    else:
        migration_config = cast("dict[str, Any]", config.migration_config)

    settings = template_settings or build_template_settings(migration_config)
    author = get_author(migration_config.get("author"), config=config)
    safe_message = _slugify_message(message)

    is_python = settings.resolve_format(file_type) == "py"
    extension = "py" if is_python else "sql"
    # An empty slug (e.g. a message of only punctuation) falls back to a generic stem.
    stem = safe_message or "migration"
    target = migrations_dir / f"{version}_{stem}.{extension}"

    context = _build_template_context(
        settings=settings,
        version=version,
        message=message,
        author=author,
        adapter=_resolve_adapter_name(config),
        project_slug=_derive_project_slug(config),
        safe_message=safe_message,
    )
    template = settings.profile.python if is_python else settings.profile.sql
    target.write_text(template.render(context), encoding="utf-8")
    return target
def get_author( author_config: Any | None = None, *, config: "DatabaseConfigProtocol[Any, Any, Any] | None" = None ) -> str: """Resolve author metadata for migration templates.""" if isinstance(author_config, str): token = author_config.strip() if not token: return _resolve_git_author() lowered = token.lower() if lowered == "git": return _resolve_git_author() if lowered == "system": return _get_system_username() if lowered.startswith("env:"): env_var = token.split(":", 1)[1].strip() if not env_var: msg = "Environment author token requires a variable name" raise TemplateValidationError(msg) return _resolve_author_from_env(env_var) if lowered.startswith("callable:"): import_path = token.split(":", 1)[1].strip() if not import_path: msg = "Callable author token requires an import path" raise TemplateValidationError(msg) return _resolve_author_callable(import_path, config) if ":" in token and " " not in token: return _resolve_author_callable(token, config) return token if isinstance(author_config, dict): mode = str(author_config.get("mode") or "static").lower() value = author_config.get("value") if mode == "static": if not isinstance(value, str) or not value.strip(): msg = "Static author value must be a non-empty string" raise TemplateValidationError(msg) return value.strip() if mode == "env": if not isinstance(value, str) or not value.strip(): msg = "Environment author mode requires an environment variable name" raise TemplateValidationError(msg) return _resolve_author_from_env(value.strip()) if mode == "callable": if not isinstance(value, str) or not value.strip(): msg = "Callable author mode requires an import path" raise TemplateValidationError(msg) return _resolve_author_callable(value.strip(), config) if mode == "system": return _get_system_username() if mode == "git": return _resolve_git_author() msg = f"Unsupported author mode '{mode}'" raise TemplateValidationError(msg) return _resolve_git_author()