Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.optimizer.qualify_columns import Resolver
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6from sqlglot.schema import ensure_schema
  7from sqlglot.errors import OptimizeError
  8from sqlglot.helper import seq_get
  9
 10# Sentinel value that means an outer query selecting ALL columns
 11SELECT_ALL = object()
 12
 13
 14# Selection to use if selection list is empty
 15def default_selection(is_agg: bool) -> exp.Alias:
 16    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
 17
 18
 19def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 20    """
 21    Rewrite sqlglot AST to remove unused columns projections.
 22
 23    Example:
 24        >>> import sqlglot
 25        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 26        >>> expression = sqlglot.parse_one(sql)
 27        >>> pushdown_projections(expression).sql()
 28        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 29
 30    Args:
 31        expression (sqlglot.Expression): expression to optimize
 32        remove_unused_selections (bool): remove selects that are unused
 33    Returns:
 34        sqlglot.Expression: optimized expression
 35    """
 36    # Map of Scope to all columns being selected by outer queries.
 37    schema = ensure_schema(schema)
 38    source_column_alias_count = {}
 39    referenced_columns = defaultdict(set)
 40
 41    # We build the scope tree (which is traversed in DFS postorder), then iterate
 42    # over the result in reverse order. This should ensure that the set of selected
 43    # columns for a particular scope are completely build by the time we get to it.
 44    for scope in reversed(traverse_scope(expression)):
 45        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 46        alias_count = source_column_alias_count.get(scope, 0)
 47
 48        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
 49        if scope.expression.args.get("distinct"):
 50            parent_selections = {SELECT_ALL}
 51
 52        if isinstance(scope.expression, exp.SetOperation):
 53            set_op = scope.expression
 54            if not (set_op.kind or set_op.side):
 55                # Do not optimize this set operation if it's using the BigQuery specific
 56                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
 57                left, right = scope.union_scopes
 58                if len(left.expression.selects) != len(right.expression.selects):
 59                    scope_sql = scope.expression.sql()
 60                    raise OptimizeError(
 61                        f"Invalid set operation due to column mismatch: {scope_sql}."
 62                    )
 63
 64                referenced_columns[left] = parent_selections
 65
 66                if any(select.is_star for select in right.expression.selects):
 67                    referenced_columns[right] = parent_selections
 68                elif not any(select.is_star for select in left.expression.selects):
 69                    if scope.expression.args.get("by_name"):
 70                        referenced_columns[right] = referenced_columns[left]
 71                    else:
 72                        referenced_columns[right] = [
 73                            right.expression.selects[i].alias_or_name
 74                            for i, select in enumerate(left.expression.selects)
 75                            if SELECT_ALL in parent_selections
 76                            or select.alias_or_name in parent_selections
 77                        ]
 78
 79        if isinstance(scope.expression, exp.Select):
 80            if remove_unused_selections:
 81                _remove_unused_selections(scope, parent_selections, schema, alias_count)
 82
 83            if scope.expression.is_star:
 84                continue
 85
 86            # Group columns by source name
 87            selects = defaultdict(set)
 88            for col in scope.columns:
 89                table_name = col.table
 90                col_name = col.name
 91                selects[table_name].add(col_name)
 92
 93            # Push the selected columns down to the next scope
 94            for name, (node, source) in scope.selected_sources.items():
 95                if isinstance(source, Scope):
 96                    select = seq_get(source.expression.selects, 0)
 97
 98                    if scope.pivots or isinstance(select, exp.QueryTransform):
 99                        columns = {SELECT_ALL}
100                    else:
101                        columns = selects.get(name) or set()
102
103                    referenced_columns[source].update(columns)
104
105                column_aliases = node.alias_column_names
106                if column_aliases:
107                    source_column_alias_count[source] = len(column_aliases)
108
109    return expression
110
111
def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """Prune the projections of a SELECT scope that no outer query references.

    The scope's SELECT expression is rewritten in place.

    Args:
        scope: scope whose ``SELECT`` projection list is rewritten.
        parent_selections: names of this scope's columns referenced by outer queries;
            if it contains ``SELECT_ALL`` every projection is kept.
        schema: schema handed to ``Resolver`` to find the owning table of columns
            re-added when a ``*`` projection is pruned.
        alias_count: number of leading projections kept unconditionally because the
            referencing node declares that many column aliases.
    """
    query = scope.expression
    order = query.args.get("order")
    # Assume unqualified ORDER BY columns are references to output columns, so the
    # projections they name must survive.
    order_refs = (
        {col.name for col in order.find_all(exp.Column) if not col.table} if order else set()
    )

    keep_everything = SELECT_ALL in parent_selections
    kept = []
    pruned_any = False
    pruned_star = False
    has_agg = False

    for projection in query.selects:
        projection_name = projection.alias_or_name
        needed = (
            keep_everything
            or projection_name in parent_selections
            or projection_name in order_refs
            or alias_count > 0
        )

        if needed:
            kept.append(projection)
            alias_count -= 1
        else:
            if projection.is_star:
                pruned_star = True
            pruned_any = True

        if not has_agg and projection.find(exp.AggFunc):
            has_agg = True

    if pruned_star:
        # Replace the pruned ``*`` with explicit references to the columns the outer
        # query still needs and that are not already projected.
        resolver = Resolver(scope, schema)
        present = {kept_projection.alias_or_name for kept_projection in kept}
        kept.extend(
            alias(exp.column(col_name, table=resolver.get_table(col_name)), col_name, copy=False)
            for col_name in sorted(parent_selections)
            if col_name not in present
        )

    # Every projection was pruned: keep a single placeholder projection instead.
    if not kept:
        kept.append(default_selection(has_agg))

    query.select(*kept, append=False, copy=False)

    if pruned_any:
        scope.clear_cache()
156    scope.expression.select(*new_selections, append=False, copy=False)
157
158    if removed:
159        scope.clear_cache()
SELECT_ALL = <object object>
def default_selection(is_agg: bool) -> sqlglot.expressions.Alias:
def default_selection(is_agg: bool) -> exp.Alias:
    """Return the placeholder projection used when every selection has been pruned.

    Args:
        is_agg: whether an aggregate placeholder is required.

    Returns:
        ``MAX(1) AS _`` when ``is_agg`` is true, otherwise ``1 AS _``.
    """
    # NOTE(review): the aggregate form is presumably kept so that pruning does not
    # change the row count of an aggregate query — confirm.
    if is_agg:
        placeholder = exp.Max(this=exp.Literal.number(1))
    else:
        placeholder = "1"
    return alias(placeholder, "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 20def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 21    """
 22    Rewrite sqlglot AST to remove unused columns projections.
 23
 24    Example:
 25        >>> import sqlglot
 26        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 27        >>> expression = sqlglot.parse_one(sql)
 28        >>> pushdown_projections(expression).sql()
 29        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 30
 31    Args:
 32        expression (sqlglot.Expression): expression to optimize
 33        remove_unused_selections (bool): remove selects that are unused
 34    Returns:
 35        sqlglot.Expression: optimized expression
 36    """
 37    # Map of Scope to all columns being selected by outer queries.
 38    schema = ensure_schema(schema)
 39    source_column_alias_count = {}
 40    referenced_columns = defaultdict(set)
 41
 42    # We build the scope tree (which is traversed in DFS postorder), then iterate
 43    # over the result in reverse order. This should ensure that the set of selected
 44    # columns for a particular scope are completely build by the time we get to it.
 45    for scope in reversed(traverse_scope(expression)):
 46        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 47        alias_count = source_column_alias_count.get(scope, 0)
 48
 49        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
 50        if scope.expression.args.get("distinct"):
 51            parent_selections = {SELECT_ALL}
 52
 53        if isinstance(scope.expression, exp.SetOperation):
 54            set_op = scope.expression
 55            if not (set_op.kind or set_op.side):
 56                # Do not optimize this set operation if it's using the BigQuery specific
 57                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
 58                left, right = scope.union_scopes
 59                if len(left.expression.selects) != len(right.expression.selects):
 60                    scope_sql = scope.expression.sql()
 61                    raise OptimizeError(
 62                        f"Invalid set operation due to column mismatch: {scope_sql}."
 63                    )
 64
 65                referenced_columns[left] = parent_selections
 66
 67                if any(select.is_star for select in right.expression.selects):
 68                    referenced_columns[right] = parent_selections
 69                elif not any(select.is_star for select in left.expression.selects):
 70                    if scope.expression.args.get("by_name"):
 71                        referenced_columns[right] = referenced_columns[left]
 72                    else:
 73                        referenced_columns[right] = [
 74                            right.expression.selects[i].alias_or_name
 75                            for i, select in enumerate(left.expression.selects)
 76                            if SELECT_ALL in parent_selections
 77                            or select.alias_or_name in parent_selections
 78                        ]
 79
 80        if isinstance(scope.expression, exp.Select):
 81            if remove_unused_selections:
 82                _remove_unused_selections(scope, parent_selections, schema, alias_count)
 83
 84            if scope.expression.is_star:
 85                continue
 86
 87            # Group columns by source name
 88            selects = defaultdict(set)
 89            for col in scope.columns:
 90                table_name = col.table
 91                col_name = col.name
 92                selects[table_name].add(col_name)
 93
 94            # Push the selected columns down to the next scope
 95            for name, (node, source) in scope.selected_sources.items():
 96                if isinstance(source, Scope):
 97                    select = seq_get(source.expression.selects, 0)
 98
 99                    if scope.pivots or isinstance(select, exp.QueryTransform):
100                        columns = {SELECT_ALL}
101                    else:
102                        columns = selects.get(name) or set()
103
104                    referenced_columns[source].update(columns)
105
106                column_aliases = node.alias_column_names
107                if column_aliases:
108                    source_column_alias_count[source] = len(column_aliases)
109
110    return expression

Rewrite the sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
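Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • schema: schema used to resolve the owning table of columns re-added when a star projection is pruned
  • remove_unused_selections (bool): remove selects that are unused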
Returns:

sqlglot.Expression: optimized expression