Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from __future__ import annotations
  2
  3import typing as t
  4from collections import defaultdict
  5
  6from sqlglot import alias, exp
  7from sqlglot.optimizer.qualify_columns import Resolver
  8from sqlglot.optimizer.scope import Scope, traverse_scope
  9from sqlglot.schema import ensure_schema
 10from sqlglot.errors import OptimizeError
 11from sqlglot.helper import seq_get
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot._typing import E
 15    from sqlglot.schema import Schema
 16    from sqlglot.dialects.dialect import DialectType
 17
# Sentinel placed in a scope's set of referenced columns to mean that an outer
# query selects ALL of that scope's columns.
SELECT_ALL = object()
 20
 21
 22# Selection to use if selection list is empty
 23def default_selection(is_agg: bool) -> exp.Alias:
 24    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_").assert_is(exp.Alias)
 25
 26
 27def pushdown_projections(
 28    expression: E,
 29    schema: dict[str, object] | Schema | None = None,
 30    remove_unused_selections: bool = True,
 31    dialect: DialectType = None,
 32) -> E:
 33    """
 34    Rewrite sqlglot AST to remove unused columns projections.
 35
 36    Example:
 37        >>> import sqlglot
 38        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 39        >>> expression = sqlglot.parse_one(sql)
 40        >>> pushdown_projections(expression).sql()
 41        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 42
 43    Args:
 44        expression (sqlglot.Expr): expression to optimize
 45        remove_unused_selections (bool): remove selects that are unused
 46    Returns:
 47        sqlglot.Expr: optimized expression
 48    """
 49    # Map of Scope to all columns being selected by outer queries.
 50    schema = ensure_schema(schema, dialect=dialect)
 51    source_column_alias_count: dict[exp.Expr | Scope, int] = {}
 52    referenced_columns: defaultdict[Scope, set[str | object]] = defaultdict(set)
 53
 54    # We build the scope tree (which is traversed in DFS postorder), then iterate
 55    # over the result in reverse order. This should ensure that the set of selected
 56    # columns for a particular scope are completely build by the time we get to it.
 57    for scope in reversed(traverse_scope(expression)):
 58        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 59        alias_count = source_column_alias_count.get(scope, 0)
 60
 61        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
 62        if scope.expression.args.get("distinct"):
 63            parent_selections = {SELECT_ALL}
 64
 65        if isinstance(scope.expression, exp.SetOperation):
 66            set_op = scope.expression
 67            if set_op.kind or set_op.side:
 68                # Do not optimize this set operation if it's using the BigQuery specific
 69                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
 70                continue
 71
 72            left, right = scope.union_scopes
 73            le = left.expression
 74            re = right.expression
 75
 76            if not (isinstance(le, exp.Selectable) and isinstance(re, exp.Selectable)):
 77                continue
 78
 79            if len(le.selects) != len(re.selects):
 80                scope_sql = scope.expression.sql(dialect=dialect)
 81                raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.")
 82
 83            referenced_columns[left] = parent_selections
 84
 85            if re.is_star:
 86                referenced_columns[right] = parent_selections
 87            elif not le.is_star:
 88                if scope.expression.args.get("by_name"):
 89                    referenced_columns[right] = referenced_columns[left]
 90                else:
 91                    referenced_columns[right] = {
 92                        re.selects[i].alias_or_name
 93                        for i, select in enumerate(le.selects)
 94                        if SELECT_ALL in parent_selections
 95                        or select.alias_or_name in parent_selections
 96                    }
 97
 98        if isinstance(scope.expression, exp.Select):
 99            if remove_unused_selections:
100                _remove_unused_selections(scope, parent_selections, schema, alias_count)
101
102            if scope.expression.is_star:
103                continue
104
105            # Group columns by source name
106            selects: dict[str, set[object]] = defaultdict(set)
107            for col in scope.columns:
108                table_name = col.table
109                col_name = col.name
110                selects[table_name].add(col_name)
111
112            # Push the selected columns down to the next scope
113            for name, (node, source) in scope.selected_sources.items():
114                if isinstance(source, Scope) and isinstance(source.expression, exp.Selectable):
115                    select = seq_get(source.expression.selects, 0)
116
117                    if scope.pivots or isinstance(select, exp.QueryTransform):
118                        columns: set[object] = {SELECT_ALL}
119                    else:
120                        columns = selects.get(name) or set()
121
122                    referenced_columns[source].update(columns)
123
124                column_aliases = node.alias_column_names
125                if column_aliases:
126                    source_column_alias_count[source] = len(column_aliases)
127
128    return expression
129
130
def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """
    Prune, in place, the projections of `scope.expression` that nothing refers to.

    A projection is kept when the parent selects everything (`SELECT_ALL`), when
    its output name is in `parent_selections`, when an unqualified ORDER BY column
    has the same name, or while `alias_count` is still positive.

    Args:
        scope: scope whose expression is rewritten
        parent_selections: names (or `SELECT_ALL`) that outer queries need
        schema: schema given to the `Resolver` when a dropped star is expanded
        alias_count: counter decremented for each kept projection; projections
            are kept unconditionally while it is positive
    """
    select = scope.expression
    order = select.args.get("order")

    # Assume columns without a qualified table are references to output columns
    order_refs = (
        {column.name for column in order.find_all(exp.Column) if not column.table}
        if order
        else set()
    )

    keep_everything = SELECT_ALL in parent_selections

    kept = []
    dropped_any = False
    dropped_star = False
    has_aggregate = False

    for projection in select.selects:
        output_name = projection.alias_or_name

        is_needed = (
            keep_everything
            or output_name in parent_selections
            or output_name in order_refs
            or alias_count > 0
        )

        if is_needed:
            kept.append(projection)
            alias_count -= 1
        else:
            dropped_any = True
            if projection.is_star:
                dropped_star = True

        if not has_aggregate and projection.find(exp.AggFunc):
            has_aggregate = True

    if dropped_star:
        # A star was dropped: add explicit columns for the names the parent needs
        # that the kept projections do not already provide.
        resolver = Resolver(scope, schema)
        kept_names = {projection.alias_or_name for projection in kept}

        for column_name in sorted(parent_selections):
            if column_name in kept_names:
                continue

            source_table = resolver.get_table(column_name)
            kept.append(
                alias(exp.column(column_name, table=source_table), column_name, copy=False)
            )

    # If there are no remaining selections, just select a single constant
    if not kept:
        kept.append(default_selection(has_aggregate))

    select.select(*kept, append=False, copy=False)

    if dropped_any:
        scope.clear_cache()
SELECT_ALL = <object object>
def default_selection(is_agg: bool) -> sqlglot.expressions.core.Alias:
def default_selection(is_agg: bool) -> exp.Alias:
    """
    Build the placeholder projection used when a selection list is empty.

    Args:
        is_agg: if true, the placeholder is `MAX(1)`; otherwise the string "1" is
            handed to `alias`.

    Returns:
        The placeholder aliased as `_`.
    """
    placeholder: exp.Expr | str
    if is_agg:
        placeholder = exp.Max(this=exp.Literal.number(1))
    else:
        placeholder = "1"

    return alias(placeholder, "_").assert_is(exp.Alias)
def pushdown_projections( expression: ~E, schema: dict[str, object] | sqlglot.schema.Schema | None = None, remove_unused_selections: bool = True, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None) -> ~E:
 28def pushdown_projections(
 29    expression: E,
 30    schema: dict[str, object] | Schema | None = None,
 31    remove_unused_selections: bool = True,
 32    dialect: DialectType = None,
 33) -> E:
 34    """
 35    Rewrite sqlglot AST to remove unused columns projections.
 36
 37    Example:
 38        >>> import sqlglot
 39        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 40        >>> expression = sqlglot.parse_one(sql)
 41        >>> pushdown_projections(expression).sql()
 42        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 43
 44    Args:
 45        expression (sqlglot.Expr): expression to optimize
 46        remove_unused_selections (bool): remove selects that are unused
 47    Returns:
 48        sqlglot.Expr: optimized expression
 49    """
 50    # Map of Scope to all columns being selected by outer queries.
 51    schema = ensure_schema(schema, dialect=dialect)
 52    source_column_alias_count: dict[exp.Expr | Scope, int] = {}
 53    referenced_columns: defaultdict[Scope, set[str | object]] = defaultdict(set)
 54
 55    # We build the scope tree (which is traversed in DFS postorder), then iterate
 56    # over the result in reverse order. This should ensure that the set of selected
 57    # columns for a particular scope are completely build by the time we get to it.
 58    for scope in reversed(traverse_scope(expression)):
 59        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 60        alias_count = source_column_alias_count.get(scope, 0)
 61
 62        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
 63        if scope.expression.args.get("distinct"):
 64            parent_selections = {SELECT_ALL}
 65
 66        if isinstance(scope.expression, exp.SetOperation):
 67            set_op = scope.expression
 68            if set_op.kind or set_op.side:
 69                # Do not optimize this set operation if it's using the BigQuery specific
 70                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
 71                continue
 72
 73            left, right = scope.union_scopes
 74            le = left.expression
 75            re = right.expression
 76
 77            if not (isinstance(le, exp.Selectable) and isinstance(re, exp.Selectable)):
 78                continue
 79
 80            if len(le.selects) != len(re.selects):
 81                scope_sql = scope.expression.sql(dialect=dialect)
 82                raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.")
 83
 84            referenced_columns[left] = parent_selections
 85
 86            if re.is_star:
 87                referenced_columns[right] = parent_selections
 88            elif not le.is_star:
 89                if scope.expression.args.get("by_name"):
 90                    referenced_columns[right] = referenced_columns[left]
 91                else:
 92                    referenced_columns[right] = {
 93                        re.selects[i].alias_or_name
 94                        for i, select in enumerate(le.selects)
 95                        if SELECT_ALL in parent_selections
 96                        or select.alias_or_name in parent_selections
 97                    }
 98
 99        if isinstance(scope.expression, exp.Select):
100            if remove_unused_selections:
101                _remove_unused_selections(scope, parent_selections, schema, alias_count)
102
103            if scope.expression.is_star:
104                continue
105
106            # Group columns by source name
107            selects: dict[str, set[object]] = defaultdict(set)
108            for col in scope.columns:
109                table_name = col.table
110                col_name = col.name
111                selects[table_name].add(col_name)
112
113            # Push the selected columns down to the next scope
114            for name, (node, source) in scope.selected_sources.items():
115                if isinstance(source, Scope) and isinstance(source.expression, exp.Selectable):
116                    select = seq_get(source.expression.selects, 0)
117
118                    if scope.pivots or isinstance(select, exp.QueryTransform):
119                        columns: set[object] = {SELECT_ALL}
120                    else:
121                        columns = selects.get(name) or set()
122
123                    referenced_columns[source].update(columns)
124
125                column_aliases = node.alias_column_names
126                if column_aliases:
127                    source_column_alias_count[source] = len(column_aliases)
128
129    return expression

Rewrite a sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
  • expression (sqlglot.Expr): expression to optimize
  • remove_unused_selections (bool): remove selects that are unused
Returns:

sqlglot.Expr: optimized expression