Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.optimizer.qualify_columns import Resolver
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6from sqlglot.schema import ensure_schema
  7from sqlglot.errors import OptimizeError
  8
  9# Sentinel value that means an outer query selecting ALL columns
 10SELECT_ALL = object()
 11
 12
 13# Selection to use if selection list is empty
 14def default_selection(is_agg: bool) -> exp.Alias:
 15    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
 16
 17
 18def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 19    """
 20    Rewrite sqlglot AST to remove unused columns projections.
 21
 22    Example:
 23        >>> import sqlglot
 24        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 25        >>> expression = sqlglot.parse_one(sql)
 26        >>> pushdown_projections(expression).sql()
 27        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 28
 29    Args:
 30        expression (sqlglot.Expression): expression to optimize
 31        remove_unused_selections (bool): remove selects that are unused
 32    Returns:
 33        sqlglot.Expression: optimized expression
 34    """
 35    # Map of Scope to all columns being selected by outer queries.
 36    schema = ensure_schema(schema)
 37    source_column_alias_count = {}
 38    referenced_columns = defaultdict(set)
 39
 40    # We build the scope tree (which is traversed in DFS postorder), then iterate
 41    # over the result in reverse order. This should ensure that the set of selected
 42    # columns for a particular scope are completely build by the time we get to it.
 43    for scope in reversed(traverse_scope(expression)):
 44        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 45        alias_count = source_column_alias_count.get(scope, 0)
 46
 47        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
 48        if scope.expression.args.get("distinct"):
 49            parent_selections = {SELECT_ALL}
 50
 51        if isinstance(scope.expression, exp.SetOperation):
 52            set_op = scope.expression
 53            if not (set_op.kind or set_op.side):
 54                # Do not optimize this set operation if it's using the BigQuery specific
 55                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
 56                left, right = scope.union_scopes
 57                if len(left.expression.selects) != len(right.expression.selects):
 58                    scope_sql = scope.expression.sql()
 59                    raise OptimizeError(
 60                        f"Invalid set operation due to column mismatch: {scope_sql}."
 61                    )
 62
 63                referenced_columns[left] = parent_selections
 64
 65                if any(select.is_star for select in right.expression.selects):
 66                    referenced_columns[right] = parent_selections
 67                elif not any(select.is_star for select in left.expression.selects):
 68                    if scope.expression.args.get("by_name"):
 69                        referenced_columns[right] = referenced_columns[left]
 70                    else:
 71                        referenced_columns[right] = [
 72                            right.expression.selects[i].alias_or_name
 73                            for i, select in enumerate(left.expression.selects)
 74                            if SELECT_ALL in parent_selections
 75                            or select.alias_or_name in parent_selections
 76                        ]
 77
 78        if isinstance(scope.expression, exp.Select):
 79            if remove_unused_selections:
 80                _remove_unused_selections(scope, parent_selections, schema, alias_count)
 81
 82            if scope.expression.is_star:
 83                continue
 84
 85            # Group columns by source name
 86            selects = defaultdict(set)
 87            for col in scope.columns:
 88                table_name = col.table
 89                col_name = col.name
 90                selects[table_name].add(col_name)
 91
 92            # Push the selected columns down to the next scope
 93            for name, (node, source) in scope.selected_sources.items():
 94                if isinstance(source, Scope):
 95                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
 96                    referenced_columns[source].update(columns)
 97
 98                column_aliases = node.alias_column_names
 99                if column_aliases:
100                    source_column_alias_count[source] = len(column_aliases)
101
102    return expression
103
104
def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """Prune, in place, the projections of ``scope``'s SELECT that no outer query needs.

    Args:
        scope: the Scope whose ``scope.expression`` SELECT is rewritten in place.
        parent_selections: column names referenced by outer queries; contains
            ``SELECT_ALL`` when every column must be kept. May be any iterable of
            names supporting ``in``, not necessarily a set.
        schema: schema handed to ``Resolver`` to find the table of each column
            that replaces a pruned ``*``.
        alias_count: number of leading projections that are always kept, because
            the parent renames them positionally via column aliases.
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    new_selections = []
    removed = False
    star = False
    is_agg = False

    select_all = SELECT_ALL in parent_selections

    for index, selection in enumerate(scope.expression.selects):
        name = selection.alias_or_name

        if (
            select_all
            or name in parent_selections
            or name in order_refs
            # The first `alias_count` projections are renamed by position in the
            # parent (e.g. `AS t(c1, c2)`), so they cannot be dropped.
            or index < alias_count
        ):
            new_selections.append(selection)
        else:
            if selection.is_star:
                star = True
            removed = True

        if not is_agg and selection.find(exp.AggFunc):
            is_agg = True

    if star:
        # A `*` was pruned: replace it with explicit columns for the names the
        # parent needs and that are not already projected. Dedupe first so a
        # non-set `parent_selections` with repeated names cannot yield duplicate
        # projections.
        resolver = Resolver(scope, schema)
        names = {s.alias_or_name for s in new_selections}

        for name in sorted(set(parent_selections) - names):
            new_selections.append(
                alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
            )

    # If there are no remaining selections, just select a single constant
    # (MAX(1) when an aggregate was present, so the projection stays an aggregate).
    if not new_selections:
        new_selections.append(default_selection(is_agg))

    scope.expression.select(*new_selections, append=False, copy=False)

    if removed:
        # Projections changed; presumably the scope's cached lookups were derived
        # from the old ones and must be recomputed.
        scope.clear_cache()
# Sentinel meaning "an outer query needs every column of this scope".
SELECT_ALL = object()


def default_selection(is_agg: bool) -> exp.Alias:
    """Return the ``_``-aliased placeholder projection for an otherwise empty SELECT.

    ``MAX(1)`` is projected when ``is_agg`` is true, the literal ``1`` otherwise.
    """
    if not is_agg:
        return alias("1", "_")
    return alias(exp.Max(this=exp.Literal.number(1)), "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: schema (anything accepted by `ensure_schema`) used to resolve the
            tables of columns introduced when a pruned `*` projection is expanded
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression
    Raises:
        OptimizeError: if both sides of a (non kind/side) set operation do not
            project the same number of columns.
    """
    schema = ensure_schema(schema)
    # Map of source -> number of column aliases given to it by its parent,
    # e.g. `(SELECT ...) AS t(c1, c2)`.
    source_column_alias_count = {}
    # Map of Scope to the set of column names being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # Scopes with no recorded references (such as the outermost query) keep
        # all of their columns.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            set_op = scope.expression
            if not (set_op.kind or set_op.side):
                # Do not optimize this set operation if it's using the BigQuery specific
                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
                left, right = scope.union_scopes
                if len(left.expression.selects) != len(right.expression.selects):
                    scope_sql = scope.expression.sql()
                    raise OptimizeError(
                        f"Invalid set operation due to column mismatch: {scope_sql}."
                    )

                referenced_columns[left] = parent_selections

                if any(select.is_star for select in right.expression.selects):
                    referenced_columns[right] = parent_selections
                elif not any(select.is_star for select in left.expression.selects):
                    if scope.expression.args.get("by_name"):
                        # BY NAME matches columns by name: the right side needs the
                        # same names as the left side.
                        referenced_columns[right] = referenced_columns[left]
                    else:
                        # Columns are matched by position: keep the right column at
                        # each position whose left column is needed. Stored as a set,
                        # like every other entry of `referenced_columns`, so names
                        # are deduplicated and membership checks are O(1).
                        referenced_columns[right] = {
                            right_select.alias_or_name
                            for left_select, right_select in zip(
                                left.expression.selects, right.expression.selects
                            )
                            if SELECT_ALL in parent_selections
                            or left_select.alias_or_name in parent_selections
                        }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    # A pivot may consume any column of its source, so keep them all.
                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression

Rewrite the sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
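Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • schema: schema used to resolve the tables of columns introduced when a pruned `*` projection is expanded
  • remove_unused_selections (bool): remove selects that are unused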
Returns:

sqlglot.Expression: optimized expression