sqlglot.optimizer.pushdown_projections
1from __future__ import annotations 2 3import typing as t 4from collections import defaultdict 5 6from sqlglot import alias, exp 7from sqlglot.optimizer.qualify_columns import Resolver 8from sqlglot.optimizer.scope import Scope, traverse_scope 9from sqlglot.schema import ensure_schema 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import seq_get 12 13if t.TYPE_CHECKING: 14 from sqlglot._typing import E 15 from sqlglot.schema import Schema 16 from sqlglot.dialects.dialect import DialectType 17 18# Sentinel value that means an outer query selecting ALL columns 19SELECT_ALL = object() 20 21 22# Selection to use if selection list is empty 23def default_selection(is_agg: bool) -> exp.Alias: 24 return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_").assert_is(exp.Alias) 25 26 27def pushdown_projections( 28 expression: E, 29 schema: dict[str, object] | Schema | None = None, 30 remove_unused_selections: bool = True, 31 dialect: DialectType = None, 32) -> E: 33 """ 34 Rewrite sqlglot AST to remove unused columns projections. 35 36 Example: 37 >>> import sqlglot 38 >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y" 39 >>> expression = sqlglot.parse_one(sql) 40 >>> pushdown_projections(expression).sql() 41 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y' 42 43 Args: 44 expression (sqlglot.Expr): expression to optimize 45 remove_unused_selections (bool): remove selects that are unused 46 Returns: 47 sqlglot.Expr: optimized expression 48 """ 49 # Map of Scope to all columns being selected by outer queries. 50 schema = ensure_schema(schema, dialect=dialect) 51 source_column_alias_count: dict[exp.Expr | Scope, int] = {} 52 referenced_columns: defaultdict[Scope, set[str | object]] = defaultdict(set) 53 54 # We build the scope tree (which is traversed in DFS postorder), then iterate 55 # over the result in reverse order. 
This should ensure that the set of selected 56 # columns for a particular scope are completely build by the time we get to it. 57 for scope in reversed(traverse_scope(expression)): 58 parent_selections = referenced_columns.get(scope, {SELECT_ALL}) 59 alias_count = source_column_alias_count.get(scope, 0) 60 61 # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. 62 if scope.expression.args.get("distinct"): 63 parent_selections = {SELECT_ALL} 64 65 if isinstance(scope.expression, exp.SetOperation): 66 set_op = scope.expression 67 if set_op.kind or set_op.side: 68 # Do not optimize this set operation if it's using the BigQuery specific 69 # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation 70 continue 71 72 left, right = scope.union_scopes 73 le = left.expression 74 re = right.expression 75 76 if not (isinstance(le, exp.Selectable) and isinstance(re, exp.Selectable)): 77 continue 78 79 if len(le.selects) != len(re.selects): 80 scope_sql = scope.expression.sql(dialect=dialect) 81 raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.") 82 83 referenced_columns[left] = parent_selections 84 85 if re.is_star: 86 referenced_columns[right] = parent_selections 87 elif not le.is_star: 88 if scope.expression.args.get("by_name"): 89 referenced_columns[right] = referenced_columns[left] 90 else: 91 referenced_columns[right] = { 92 re.selects[i].alias_or_name 93 for i, select in enumerate(le.selects) 94 if SELECT_ALL in parent_selections 95 or select.alias_or_name in parent_selections 96 } 97 98 if isinstance(scope.expression, exp.Select): 99 if remove_unused_selections: 100 _remove_unused_selections(scope, parent_selections, schema, alias_count) 101 102 if scope.expression.is_star: 103 continue 104 105 # Group columns by source name 106 selects: dict[str, set[object]] = defaultdict(set) 107 for col in scope.columns: 108 table_name = col.table 109 col_name = col.name 110 
selects[table_name].add(col_name) 111 112 # Push the selected columns down to the next scope 113 for name, (node, source) in scope.selected_sources.items(): 114 if isinstance(source, Scope) and isinstance(source.expression, exp.Selectable): 115 select = seq_get(source.expression.selects, 0) 116 117 if scope.pivots or isinstance(select, exp.QueryTransform): 118 columns: set[object] = {SELECT_ALL} 119 else: 120 columns = selects.get(name) or set() 121 122 referenced_columns[source].update(columns) 123 124 column_aliases = node.alias_column_names 125 if column_aliases: 126 source_column_alias_count[source] = len(column_aliases) 127 128 return expression 129 130 131def _remove_unused_selections(scope, parent_selections, schema, alias_count): 132 order = scope.expression.args.get("order") 133 134 if order: 135 # Assume columns without a qualified table are references to output columns 136 order_refs = {c.name for c in order.find_all(exp.Column) if not c.table} 137 else: 138 order_refs = set() 139 140 new_selections = [] 141 removed = False 142 star = False 143 is_agg = False 144 145 select_all = SELECT_ALL in parent_selections 146 147 for selection in scope.expression.selects: 148 name = selection.alias_or_name 149 150 if select_all or name in parent_selections or name in order_refs or alias_count > 0: 151 new_selections.append(selection) 152 alias_count -= 1 153 else: 154 if selection.is_star: 155 star = True 156 removed = True 157 158 if not is_agg and selection.find(exp.AggFunc): 159 is_agg = True 160 161 if star: 162 resolver = Resolver(scope, schema) 163 names = {s.alias_or_name for s in new_selections} 164 165 for name in sorted(parent_selections): 166 if name not in names: 167 new_selections.append( 168 alias(exp.column(name, table=resolver.get_table(name)), name, copy=False) 169 ) 170 171 # If there are no remaining selections, just select a single constant 172 if not new_selections: 173 new_selections.append(default_selection(is_agg)) 174 175 
scope.expression.select(*new_selections, append=False, copy=False) 176 177 if removed: 178 scope.clear_cache()
SELECT_ALL =
<object object>
def
pushdown_projections( expression: ~E, schema: dict[str, object] | sqlglot.schema.Schema | None = None, remove_unused_selections: bool = True, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None) -> ~E:
def pushdown_projections(
    expression: E,
    schema: dict[str, object] | Schema | None = None,
    remove_unused_selections: bool = True,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite the sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expr): expression to optimize
        schema: schema information, normalized through `ensure_schema` and passed
            on to the helper that removes unused selections
        remove_unused_selections (bool): remove selects that are unused
        dialect: dialect forwarded to `ensure_schema` and used to render the SQL
            shown in error messages
    Returns:
        sqlglot.Expr: optimized expression
    Raises:
        OptimizeError: if the two sides of a set operation have a different
            number of projections
    """
    schema = ensure_schema(schema, dialect=dialect)

    # Number of column aliases declared by the node that references each source.
    source_column_alias_count: dict[exp.Expr | Scope, int] = {}
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns: defaultdict[Scope, set[str | object]] = defaultdict(set)

    # The scope list comes back in DFS postorder; walking it backwards means the
    # set of columns requested from a scope should be completely built by the
    # time that scope itself is processed.
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        # Columns can't be removed from SELECT DISTINCT nor UNION DISTINCT.
        if query.args.get("distinct"):
            parent_selections: set[str | object] = {SELECT_ALL}
        else:
            parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if isinstance(query, exp.SetOperation):
            # Do not optimize this set operation if it's using the BigQuery specific
            # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the
            # semantics of the operation.
            if query.kind or query.side:
                continue

            left, right = scope.union_scopes
            left_query = left.expression
            right_query = right.expression

            if not isinstance(left_query, exp.Selectable) or not isinstance(
                right_query, exp.Selectable
            ):
                continue

            if len(left_query.selects) != len(right_query.selects):
                scope_sql = query.sql(dialect=dialect)
                raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.")

            referenced_columns[left] = parent_selections

            if right_query.is_star:
                referenced_columns[right] = parent_selections
            elif not left_query.is_star:
                if query.args.get("by_name"):
                    referenced_columns[right] = referenced_columns[left]
                else:
                    # Projections are paired by position: keep the right-hand
                    # projection whenever its left-hand counterpart is requested.
                    wants_all = SELECT_ALL in parent_selections
                    referenced_columns[right] = {
                        right_select.alias_or_name
                        for left_select, right_select in zip(
                            left_query.selects, right_query.selects
                        )
                        if wants_all or left_select.alias_or_name in parent_selections
                    }

        if isinstance(query, exp.Select):
            if remove_unused_selections:
                alias_count = source_column_alias_count.get(scope, 0)
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            columns_by_source: dict[str, set[object]] = defaultdict(set)
            for column in scope.columns:
                columns_by_source[column.table].add(column.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope) and isinstance(source.expression, exp.Selectable):
                    first_select = seq_get(source.expression.selects, 0)

                    if scope.pivots or isinstance(first_select, exp.QueryTransform):
                        pushed: set[object] = {SELECT_ALL}
                    else:
                        pushed = columns_by_source.get(name) or set()

                    referenced_columns[source].update(pushed)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression
Rewrite the sqlglot AST to remove unused column projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
- expression (sqlglot.Expr): expression to optimize
- remove_unused_selections (bool): remove selects that are unused
Returns:
sqlglot.Expr: optimized expression