Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from __future__ import annotations
  2
  3import typing as t
  4from collections import defaultdict
  5
  6from sqlglot import alias, exp
  7from sqlglot.optimizer.qualify_columns import Resolver
  8from sqlglot.optimizer.scope import Scope, traverse_scope
  9from sqlglot.schema import ensure_schema
 10from sqlglot.errors import OptimizeError
 11from sqlglot.helper import seq_get
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot._typing import E
 15    from sqlglot.schema import Schema
 16    from sqlglot.dialects.dialect import DialectType
 17
 18# Sentinel value that means an outer query selecting ALL columns
 19SELECT_ALL = object()
 20
 21
 22# Selection to use if selection list is empty
 23def default_selection(is_agg: bool) -> exp.Alias:
 24    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
 25
 26
 27def pushdown_projections(
 28    expression: E,
 29    schema: t.Optional[t.Dict | Schema] = None,
 30    remove_unused_selections: bool = True,
 31    dialect: DialectType = None,
 32) -> E:
 33    """
 34    Rewrite sqlglot AST to remove unused columns projections.
 35
 36    Example:
 37        >>> import sqlglot
 38        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 39        >>> expression = sqlglot.parse_one(sql)
 40        >>> pushdown_projections(expression).sql()
 41        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 42
 43    Args:
 44        expression (sqlglot.Expression): expression to optimize
 45        remove_unused_selections (bool): remove selects that are unused
 46    Returns:
 47        sqlglot.Expression: optimized expression
 48    """
 49    # Map of Scope to all columns being selected by outer queries.
 50    schema = ensure_schema(schema, dialect=dialect)
 51    source_column_alias_count: t.Dict[exp.Expression | Scope, int] = {}
 52    referenced_columns: t.DefaultDict[Scope, t.Set[str | object]] = defaultdict(set)
 53
 54    # We build the scope tree (which is traversed in DFS postorder), then iterate
 55    # over the result in reverse order. This should ensure that the set of selected
 56    # columns for a particular scope are completely build by the time we get to it.
 57    for scope in reversed(traverse_scope(expression)):
 58        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 59        alias_count = source_column_alias_count.get(scope, 0)
 60
 61        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
 62        if scope.expression.args.get("distinct"):
 63            parent_selections = {SELECT_ALL}
 64
 65        if isinstance(scope.expression, exp.SetOperation):
 66            set_op = scope.expression
 67            if not (set_op.kind or set_op.side):
 68                # Do not optimize this set operation if it's using the BigQuery specific
 69                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
 70                left, right = scope.union_scopes
 71                if len(left.expression.selects) != len(right.expression.selects):
 72                    scope_sql = scope.expression.sql()
 73                    raise OptimizeError(
 74                        f"Invalid set operation due to column mismatch: {scope_sql}."
 75                    )
 76
 77                referenced_columns[left] = parent_selections
 78
 79                if any(select.is_star for select in right.expression.selects):
 80                    referenced_columns[right] = parent_selections
 81                elif not any(select.is_star for select in left.expression.selects):
 82                    if scope.expression.args.get("by_name"):
 83                        referenced_columns[right] = referenced_columns[left]
 84                    else:
 85                        referenced_columns[right] = {
 86                            right.expression.selects[i].alias_or_name
 87                            for i, select in enumerate(left.expression.selects)
 88                            if SELECT_ALL in parent_selections
 89                            or select.alias_or_name in parent_selections
 90                        }
 91
 92        if isinstance(scope.expression, exp.Select):
 93            if remove_unused_selections:
 94                _remove_unused_selections(scope, parent_selections, schema, alias_count)
 95
 96            if scope.expression.is_star:
 97                continue
 98
 99            # Group columns by source name
100            selects = defaultdict(set)
101            for col in scope.columns:
102                table_name = col.table
103                col_name = col.name
104                selects[table_name].add(col_name)
105
106            # Push the selected columns down to the next scope
107            for name, (node, source) in scope.selected_sources.items():
108                if isinstance(source, Scope):
109                    select = seq_get(source.expression.selects, 0)
110
111                    if scope.pivots or isinstance(select, exp.QueryTransform):
112                        columns = {SELECT_ALL}
113                    else:
114                        columns = selects.get(name) or set()
115
116                    referenced_columns[source].update(columns)
117
118                column_aliases = node.alias_column_names
119                if column_aliases:
120                    source_column_alias_count[source] = len(column_aliases)
121
122    return expression
123
124
def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """Drop the projections of ``scope``'s SELECT that nothing outside it needs.

    A projection is kept when any of the following holds:
      * ``SELECT_ALL`` is in ``parent_selections`` (the outer query needs everything);
      * its output name is in ``parent_selections``;
      * an unqualified column with that name appears in the ORDER BY clause;
      * it is among the first ``alias_count`` projections (positional column aliases).

    If a ``*`` projection is dropped, every name in ``parent_selections`` not already
    produced by a kept projection is re-added as an explicit aliased column. Its table
    comes from a ``Resolver`` built over ``scope`` and ``schema``. If no projection
    survives, the placeholder from ``default_selection`` is used. The SELECT is
    rewritten in place, and the scope's cache is cleared when anything was dropped.
    """
    select = scope.expression
    order = select.args.get("order")

    # Unqualified ORDER BY columns are assumed to reference output columns.
    ordered_names = (
        {column.name for column in order.find_all(exp.Column) if not column.table}
        if order
        else set()
    )
    keep_all = SELECT_ALL in parent_selections
    aliases_left = alias_count

    kept = []
    dropped_any = False
    dropped_star = False
    has_agg = False

    for projection in select.selects:
        output_name = projection.alias_or_name
        required = (
            keep_all
            or output_name in parent_selections
            or output_name in ordered_names
            or aliases_left > 0
        )

        if required:
            kept.append(projection)
            aliases_left -= 1
        else:
            if projection.is_star:
                dropped_star = True
            dropped_any = True

        # Scan every projection (kept or dropped) for aggregates.
        if not has_agg and projection.find(exp.AggFunc):
            has_agg = True

    if dropped_star:
        resolver = Resolver(scope, schema)
        provided = {projection.alias_or_name for projection in kept}
        for name in sorted(parent_selections):
            if name in provided:
                continue
            column = exp.column(name, table=resolver.get_table(name))
            kept.append(alias(column, name, copy=False))

    # An empty projection list is invalid SQL, so fall back to a constant.
    if not kept:
        kept.append(default_selection(has_agg))

    select.select(*kept, append=False, copy=False)

    if dropped_any:
        scope.clear_cache()
SELECT_ALL = <object object>
def default_selection(is_agg: bool) -> sqlglot.expressions.Alias:
def default_selection(is_agg: bool) -> exp.Alias:
    """Build the placeholder projection used when a SELECT has no projections left.

    Args:
        is_agg: whether the projections being replaced contained an aggregate function.

    Returns:
        ``MAX(1) AS _`` when ``is_agg`` is true, otherwise ``1 AS _``. The aggregate
        form keeps an aggregating query aggregating; a bare ``1`` without GROUP BY would
        yield one row per input row instead of a single row.
    """
    if not is_agg:
        return alias("1", "_")
    return alias(exp.Max(this=exp.Literal.number(1)), "_")
def pushdown_projections( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, remove_unused_selections: bool = True, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> ~E:
 28def pushdown_projections(
 29    expression: E,
 30    schema: t.Optional[t.Dict | Schema] = None,
 31    remove_unused_selections: bool = True,
 32    dialect: DialectType = None,
 33) -> E:
 34    """
 35    Rewrite sqlglot AST to remove unused columns projections.
 36
 37    Example:
 38        >>> import sqlglot
 39        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 40        >>> expression = sqlglot.parse_one(sql)
 41        >>> pushdown_projections(expression).sql()
 42        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 43
 44    Args:
 45        expression (sqlglot.Expression): expression to optimize
 46        remove_unused_selections (bool): remove selects that are unused
 47    Returns:
 48        sqlglot.Expression: optimized expression
 49    """
 50    # Map of Scope to all columns being selected by outer queries.
 51    schema = ensure_schema(schema, dialect=dialect)
 52    source_column_alias_count: t.Dict[exp.Expression | Scope, int] = {}
 53    referenced_columns: t.DefaultDict[Scope, t.Set[str | object]] = defaultdict(set)
 54
 55    # We build the scope tree (which is traversed in DFS postorder), then iterate
 56    # over the result in reverse order. This should ensure that the set of selected
 57    # columns for a particular scope are completely build by the time we get to it.
 58    for scope in reversed(traverse_scope(expression)):
 59        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 60        alias_count = source_column_alias_count.get(scope, 0)
 61
 62        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
 63        if scope.expression.args.get("distinct"):
 64            parent_selections = {SELECT_ALL}
 65
 66        if isinstance(scope.expression, exp.SetOperation):
 67            set_op = scope.expression
 68            if not (set_op.kind or set_op.side):
 69                # Do not optimize this set operation if it's using the BigQuery specific
 70                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
 71                left, right = scope.union_scopes
 72                if len(left.expression.selects) != len(right.expression.selects):
 73                    scope_sql = scope.expression.sql()
 74                    raise OptimizeError(
 75                        f"Invalid set operation due to column mismatch: {scope_sql}."
 76                    )
 77
 78                referenced_columns[left] = parent_selections
 79
 80                if any(select.is_star for select in right.expression.selects):
 81                    referenced_columns[right] = parent_selections
 82                elif not any(select.is_star for select in left.expression.selects):
 83                    if scope.expression.args.get("by_name"):
 84                        referenced_columns[right] = referenced_columns[left]
 85                    else:
 86                        referenced_columns[right] = {
 87                            right.expression.selects[i].alias_or_name
 88                            for i, select in enumerate(left.expression.selects)
 89                            if SELECT_ALL in parent_selections
 90                            or select.alias_or_name in parent_selections
 91                        }
 92
 93        if isinstance(scope.expression, exp.Select):
 94            if remove_unused_selections:
 95                _remove_unused_selections(scope, parent_selections, schema, alias_count)
 96
 97            if scope.expression.is_star:
 98                continue
 99
100            # Group columns by source name
101            selects = defaultdict(set)
102            for col in scope.columns:
103                table_name = col.table
104                col_name = col.name
105                selects[table_name].add(col_name)
106
107            # Push the selected columns down to the next scope
108            for name, (node, source) in scope.selected_sources.items():
109                if isinstance(source, Scope):
110                    select = seq_get(source.expression.selects, 0)
111
112                    if scope.pivots or isinstance(select, exp.QueryTransform):
113                        columns = {SELECT_ALL}
114                    else:
115                        columns = selects.get(name) or set()
116
117                    referenced_columns[source].update(columns)
118
119                column_aliases = node.alias_column_names
120                if column_aliases:
121                    source_column_alias_count[source] = len(column_aliases)
122
123    return expression

Rewrite a sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
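Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • schema (dict or sqlglot.schema.Schema, optional): schema used to resolve the source table of columns that replace a pruned star projection
  • remove_unused_selections (bool): remove selects that are unused
  • dialect (DialectType, optional): dialect used when normalizing the schema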
Returns:

sqlglot.Expression: optimized expression