Source code for sqlspec.migrations.templates

"""Migration template rendering and configuration utilities."""

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from sqlspec.exceptions import SQLSpecError

if TYPE_CHECKING:
    from collections.abc import Mapping

# Public API of this module: the names exported by
# ``from sqlspec.migrations.templates import *``.
__all__ = (
    "MigrationTemplateProfile",
    "MigrationTemplateSettings",
    "PythonTemplateDefinition",
    "SQLTemplateDefinition",
    "TemplateDescriptionHints",
    "TemplateValidationError",
    "build_template_settings",
)


class TemplateValidationError(SQLSpecError):
    """Signal that a migration template definition failed validation.

    Raised, for example, when a template fragment references a variable
    missing from the rendering context or is not a valid format string.
    """

    pass


@dataclass(slots=True)
class TemplateDescriptionHints:
    """Hints for extracting descriptions from rendered templates."""

    sql_keys: "tuple[str, ...]" = ("Description",)
    python_keys: "tuple[str, ...]" = ("Description",)


@dataclass(slots=True)
class SQLTemplateDefinition:
    """SQL migration template fragments."""

    header: str
    metadata: "list[str]" = field(default_factory=list)
    body: str = ""
    description_keys: "tuple[str, ...]" = ("Description",)

    def render(self, context: "Mapping[str, str]") -> str:
        """Render the SQL template using the supplied context."""

        rendered_lines: list[str] = [self._format(self.header, context)]
        rendered_lines.extend(self._format(line, context) for line in self.metadata if line)
        rendered_lines.extend(("", self._format(self.body, context)))
        return "\n".join(_normalize_newlines(rendered_lines)).rstrip() + "\n"
def _format(self, template: str, context: "Mapping[str, str]") -> str: try: return template.format_map(context) except KeyError as exc: # pragma: no cover missing = str(exc).strip("'") msg = f"Missing template variable '{missing}' in SQL template" raise TemplateValidationError(msg) from exc except ValueError as exc: # pragma: no cover msg = f"Invalid SQL template fragment: {exc}" raise TemplateValidationError(msg) from exc