-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathoptimizer.py
More file actions
114 lines (100 loc) · 4.2 KB
/
optimizer.py
File metadata and controls
114 lines (100 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import annotations
import inspect
import typing as t
from collections.abc import Sequence
from sqlglot import Schema, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify import qualify
from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
class OptimizerFn(t.Protocol):
    """Structural type that every optimizer rule conforms to.

    A rule is any callable that:

    - takes the expression to rewrite (an `exp.Expr`) as its first argument,
    - may accept arbitrary further positional and keyword arguments, and
    - returns the rewritten `exp.Expr`.

    Note:
        This shape ("first argument fixed, then anything") cannot be spelled with
        `collections.abc.Callable`, which is why a `typing.Protocol` is used instead.
    """

    def __call__(self, expression: exp.Expr, *args: t.Any, **kwargs: t.Any) -> exp.Expr:
        """Rewrite `expression` and return the resulting expression."""
# Default rule pipeline used by `optimize`.
# `optimize` applies the rules one after another, feeding each rule the output of
# the previous one, so the order of this tuple is significant.
# `qualify` comes first because many of the later rules require tables and
# columns to be qualified.
RULES: tuple[OptimizerFn, ...] = (
    qualify,
    pushdown_projections,
    normalize,
    unnest_subqueries,
    pushdown_predicates,
    optimize_joins,
    eliminate_subqueries,
    merge_subqueries,
    eliminate_joins,
    eliminate_ctes,
    quote_identifiers,
    annotate_types,
    canonicalize,
    simplify,
)
def _rule_kwargs(rule: OptimizerFn, possible_kwargs: dict[str, t.Any]) -> dict[str, t.Any]:
    """Select the entries of `possible_kwargs` that `rule` names as parameters.

    The rule's leading positional parameter receives the expression being optimized,
    so it is never filled from `possible_kwargs`. Both positional-or-keyword and
    keyword-only parameters are matched by name. `*args`/`**kwargs` catch-alls are
    not, so a rule only receives the options it declares explicitly.

    `inspect.signature` is used rather than `inspect.getfullargspec` for two reasons.
    Keyword-only parameters are not silently ignored. Rules decorated with
    `functools.wraps` expose the wrapped function's parameters.
    """
    parameters = list(inspect.signature(rule).parameters.values())

    # The first positional slot is the expression itself; also passing it by
    # keyword would raise "got multiple values for argument".
    if parameters and parameters[0].kind in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    ):
        parameters = parameters[1:]

    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return {
        param.name: possible_kwargs[param.name]
        for param in parameters
        if param.kind in keyword_kinds and param.name in possible_kwargs
    }


def optimize(
    expression: str | exp.Expr,
    schema: dict[str, object] | Schema | None = None,
    db: str | exp.Identifier | None = None,
    catalog: str | exp.Identifier | None = None,
    dialect: DialectType = None,
    rules: Sequence[OptimizerFn] = RULES,
    sql: str | None = None,
    **kwargs: object,
) -> exp.Expr:
    """
    Rewrite a sqlglot AST into an optimized form.

    Args:
        expression: expression to optimize. It goes through
            `exp.maybe_parse(..., copy=True)`: a string is parsed using `dialect`,
            and an existing expression is copied rather than rewritten in place.
        schema: database schema.
            This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
            the following forms:
                1. {table: {col: type}}
                2. {db: {table: {col: type}}}
                3. {catalog: {db: {table: {col: type}}}}
            If no schema is provided then the default schema defined at `sqlglot.schema` will be used
        db: specify the default database, as might be set by a `USE DATABASE db` statement
        catalog: specify the default catalog, as might be set by a `USE CATALOG c` statement
        dialect: The dialect to parse the sql string.
        rules: sequence of optimizer rules to use, applied in order; each rule receives
            the output of the previous one.
            Many of the rules require tables and columns to be qualified.
            Do not remove `qualify` from the sequence of rules unless you know what you're doing!
        sql: Original SQL string for error highlighting. If not provided, errors will not include
            highlighting. Requires that the expression has position metadata from parsing.
        **kwargs: If a rule has a parameter (positional-or-keyword or keyword-only) with the
            same name as an entry in **kwargs, that value is passed to the rule. Entries here
            also override the defaults `optimize` supplies for `isolate_tables` and
            `quote_identifiers`.

    Returns:
        The optimized expression.
    """
    schema = ensure_schema(schema, dialect=dialect)
    possible_kwargs = {
        "db": db,
        "catalog": catalog,
        "schema": schema,
        "dialect": dialect,
        "sql": sql,
        "isolate_tables": True,  # needed for other optimizations to perform well
        # NOTE(review): quoting appears to be left to the dedicated
        # `quote_identifiers` rule in the pipeline — confirm intent.
        "quote_identifiers": False,
        # Caller-supplied options come last so they take precedence.
        **kwargs,
    }

    optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
    for rule in rules:
        # Pass only the options this rule declares, beyond the expression itself.
        optimized = rule(optimized, **_rule_kwargs(rule, possible_kwargs))

    return optimized