sqlglot.optimizer.pushdown_projections
from __future__ import annotations

import typing as t
from collections import defaultdict

from sqlglot import alias, exp
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get

if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from sqlglot.schema import Schema
    from sqlglot.dialects.dialect import DialectType

# Sentinel value that means an outer query is selecting ALL columns
SELECT_ALL = object()


def default_selection(is_agg: bool) -> exp.Alias:
    """
    Build the selection to use when a selection list would otherwise be empty.

    Args:
        is_agg: when true the placeholder is `MAX(1)`, otherwise it is the literal `1`.

    Returns:
        The placeholder aliased as `_`.
    """
    if is_agg:
        placeholder = exp.Max(this=exp.Literal.number(1))
    else:
        placeholder = "1"

    return alias(placeholder, "_")


def pushdown_projections(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    remove_unused_selections: bool = True,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite a sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema (dict|Schema): normalized with `ensure_schema` and forwarded to
            `_remove_unused_selections`
        remove_unused_selections (bool): remove selects that are unused
        dialect (DialectType): forwarded to `ensure_schema`

    Returns:
        sqlglot.Expression: optimized expression

    Raises:
        OptimizeError: if the two operands of a set operation project a different
            number of columns
    """
    schema = ensure_schema(schema, dialect=dialect)

    # Number of column aliases declared on each source by the query selecting from it.
    source_column_alias_count: t.Dict[exp.Expression | Scope, int] = {}
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns: t.DefaultDict[Scope, t.Set[str | object]] = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        if query.args.get("distinct"):
            # Columns can't be removed from SELECT DISTINCT nor from UNION DISTINCT.
            outer_columns = {SELECT_ALL}
        else:
            # A scope nobody registered columns for is treated as fully selected.
            outer_columns = referenced_columns.get(scope, {SELECT_ALL})

        # Set operations using the BigQuery specific kind / side syntax
        # (e.g INNER UNION ALL BY NAME) are not optimized, because that syntax
        # changes the semantics of the operation.
        if isinstance(query, exp.SetOperation) and not (query.kind or query.side):
            left, right = scope.union_scopes
            left_selects = left.expression.selects
            right_selects = right.expression.selects

            if len(left_selects) != len(right_selects):
                scope_sql = query.sql()
                raise OptimizeError(
                    f"Invalid set operation due to column mismatch: {scope_sql}."
                )

            referenced_columns[left] = outer_columns

            if any(projection.is_star for projection in right_selects):
                referenced_columns[right] = outer_columns
            elif not any(projection.is_star for projection in left_selects):
                if query.args.get("by_name"):
                    referenced_columns[right] = referenced_columns[left]
                else:
                    # Operands are matched by position: keep the right-hand projection
                    # sitting at the same index as every requested left-hand one.
                    keep_all = SELECT_ALL in outer_columns
                    referenced_columns[right] = {
                        right_select.alias_or_name
                        for left_select, right_select in zip(left_selects, right_selects)
                        if keep_all or left_select.alias_or_name in outer_columns
                    }

        if not isinstance(query, exp.Select):
            continue

        if remove_unused_selections:
            alias_count = source_column_alias_count.get(scope, 0)
            _remove_unused_selections(scope, outer_columns, schema, alias_count)

        if query.is_star:
            continue

        # Group columns by source name
        columns_by_source = defaultdict(set)
        for column in scope.columns:
            columns_by_source[column.table].add(column.name)

        # Push the selected columns down to the next scope
        for source_name, (node, source) in scope.selected_sources.items():
            if isinstance(source, Scope):
                first_select = seq_get(source.expression.selects, 0)

                if scope.pivots or isinstance(first_select, exp.QueryTransform):
                    referenced_columns[source].add(SELECT_ALL)
                else:
                    referenced_columns[source].update(columns_by_source.get(source_name, ()))

            aliases = node.alias_column_names
            if aliases:
                source_column_alias_count[source] = len(aliases)

    return expression


def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """
    Prune, in place, the projections of `scope.expression` that nothing refers to.

    Args:
        scope: scope whose expression's selections are rewritten.
        parent_selections: names requested by outer queries; may contain `SELECT_ALL`.
        schema: passed to `Resolver` when a pruned `*` has to be replaced by columns.
        alias_count: while this counter is positive, selections are kept regardless of
            whether they are referenced; it is decremented for every kept selection.
    """
    query = scope.expression
    order = query.args.get("order")

    # Assume columns without a qualified table are references to output columns
    order_refs = (
        {column.name for column in order.find_all(exp.Column) if not column.table}
        if order
        else set()
    )

    keep_everything = SELECT_ALL in parent_selections
    kept = []
    removed = False
    star = False
    is_agg = False

    for selection in query.selects:
        name = selection.alias_or_name
        needed = (
            keep_everything
            or name in parent_selections
            or name in order_refs
            or alias_count > 0
        )

        if needed:
            kept.append(selection)
            alias_count -= 1
        else:
            removed = True
            if selection.is_star:
                star = True

        if not is_agg and selection.find(exp.AggFunc):
            is_agg = True

    if star:
        # A `*` was dropped: add every requested name that is not already projected
        # as an explicit, aliased column.
        resolver = Resolver(scope, schema)
        present = {projection.alias_or_name for projection in kept}

        for name in sorted(parent_selections):
            if name in present:
                continue
            column = exp.column(name, table=resolver.get_table(name))
            kept.append(alias(column, name, copy=False))

    # If there are no remaining selections, just select a single constant
    if not kept:
        kept.append(default_selection(is_agg))

    query.select(*kept, append=False, copy=False)

    if removed:
        scope.clear_cache()
SELECT_ALL =
<object object>
def
pushdown_projections( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, remove_unused_selections: bool = True, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> ~E:
def pushdown_projections(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    remove_unused_selections: bool = True,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite a sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema (dict|Schema): normalized with `ensure_schema` and forwarded to
            `_remove_unused_selections`
        remove_unused_selections (bool): remove selects that are unused
        dialect (DialectType): forwarded to `ensure_schema`

    Returns:
        sqlglot.Expression: optimized expression

    Raises:
        OptimizeError: if the two operands of a set operation project a different
            number of columns
    """
    schema = ensure_schema(schema, dialect=dialect)

    # Number of column aliases declared on each source by the query selecting from it.
    source_column_alias_count: t.Dict[exp.Expression | Scope, int] = {}
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns: t.DefaultDict[Scope, t.Set[str | object]] = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        if query.args.get("distinct"):
            # Columns can't be removed from SELECT DISTINCT nor from UNION DISTINCT.
            outer_columns = {SELECT_ALL}
        else:
            # A scope nobody registered columns for is treated as fully selected.
            outer_columns = referenced_columns.get(scope, {SELECT_ALL})

        # Set operations using the BigQuery specific kind / side syntax
        # (e.g INNER UNION ALL BY NAME) are not optimized, because that syntax
        # changes the semantics of the operation.
        if isinstance(query, exp.SetOperation) and not (query.kind or query.side):
            left, right = scope.union_scopes
            left_selects = left.expression.selects
            right_selects = right.expression.selects

            if len(left_selects) != len(right_selects):
                scope_sql = query.sql()
                raise OptimizeError(
                    f"Invalid set operation due to column mismatch: {scope_sql}."
                )

            referenced_columns[left] = outer_columns

            if any(projection.is_star for projection in right_selects):
                referenced_columns[right] = outer_columns
            elif not any(projection.is_star for projection in left_selects):
                if query.args.get("by_name"):
                    referenced_columns[right] = referenced_columns[left]
                else:
                    # Operands are matched by position: keep the right-hand projection
                    # sitting at the same index as every requested left-hand one.
                    keep_all = SELECT_ALL in outer_columns
                    referenced_columns[right] = {
                        right_select.alias_or_name
                        for left_select, right_select in zip(left_selects, right_selects)
                        if keep_all or left_select.alias_or_name in outer_columns
                    }

        if not isinstance(query, exp.Select):
            continue

        if remove_unused_selections:
            alias_count = source_column_alias_count.get(scope, 0)
            _remove_unused_selections(scope, outer_columns, schema, alias_count)

        if query.is_star:
            continue

        # Group columns by source name
        columns_by_source = defaultdict(set)
        for column in scope.columns:
            columns_by_source[column.table].add(column.name)

        # Push the selected columns down to the next scope
        for source_name, (node, source) in scope.selected_sources.items():
            if isinstance(source, Scope):
                first_select = seq_get(source.expression.selects, 0)

                if scope.pivots or isinstance(first_select, exp.QueryTransform):
                    referenced_columns[source].add(SELECT_ALL)
                else:
                    referenced_columns[source].update(columns_by_source.get(source_name, ()))

            aliases = node.alias_column_names
            if aliases:
                source_column_alias_count[source] = len(aliases)

    return expression
Rewrite a sqlglot AST to remove unused column projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
- expression (sqlglot.Expression): expression to optimize
- remove_unused_selections (bool): remove selects that are unused
Returns:
sqlglot.Expression: optimized expression