sqlglot.optimizer.pushdown_projections
from collections import defaultdict

from sqlglot import alias, exp
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
from sqlglot.errors import OptimizeError

# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()


# Selection to use if selection list is empty
def default_selection(is_agg: bool) -> exp.Alias:
    """Return a placeholder projection aliased as ``_``.

    ``MAX(1)`` is used when the query aggregates (so the projection stays an
    aggregate), otherwise the constant ``1``.
    """
    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")


def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: schema used to resolve column owners when an unused ``*``
            projection is replaced by explicit columns; normalized with
            ``ensure_schema``
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: the same expression, modified in place
    Raises:
        OptimizeError: if the two branches of a set operation (one without the
            BigQuery-specific kind/side modifiers) project a different number
            of columns.
    """
    schema = ensure_schema(schema)
    # Number of column aliases declared on a derived table, e.g. `AS t(a, b)`, keyed by
    # the scope that produces it; those leading projections are kept positionally.
    source_column_alias_count = {}
    # Map of Scope to all columns being selected by outer queries. Every value is a set.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            set_op = scope.expression
            if not (set_op.kind or set_op.side):
                # Do not optimize this set operation if it's using the BigQuery specific
                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
                left, right = scope.union_scopes
                if len(left.expression.selects) != len(right.expression.selects):
                    scope_sql = scope.expression.sql()
                    raise OptimizeError(
                        f"Invalid set operation due to column mismatch: {scope_sql}."
                    )

                referenced_columns[left] = parent_selections

                if any(select.is_star for select in right.expression.selects):
                    referenced_columns[right] = parent_selections
                elif not any(select.is_star for select in left.expression.selects):
                    if scope.expression.args.get("by_name"):
                        referenced_columns[right] = referenced_columns[left]
                    else:
                        # Columns match positionally: keep the right-hand projection at
                        # every index whose left-hand projection is still needed. Build a
                        # set, like every other entry of referenced_columns.
                        referenced_columns[right] = {
                            right.expression.selects[i].alias_or_name
                            for i, select in enumerate(left.expression.selects)
                            if SELECT_ALL in parent_selections
                            or select.alias_or_name in parent_selections
                        }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    # A pivot can reference any column of its source, so keep them all.
                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression


def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """Drop projections of ``scope`` that no outer query needs, in place.

    Args:
        scope: the SELECT scope whose projection list is rewritten
        parent_selections: names referenced by outer queries; containing
            ``SELECT_ALL`` keeps every projection
        schema: schema passed to ``Resolver`` when a removed ``*`` has to be
            replaced by explicit columns
        alias_count: number of leading projections to keep unconditionally
            (the derived table declares that many column aliases)
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    new_selections = []
    removed = False
    star = False
    is_agg = False

    select_all = SELECT_ALL in parent_selections

    for selection in scope.expression.selects:
        name = selection.alias_or_name

        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
            new_selections.append(selection)
            alias_count -= 1
        else:
            if selection.is_star:
                star = True
            removed = True

        if not is_agg and selection.find(exp.AggFunc):
            is_agg = True

    if star:
        # The `*` was dropped: project the referenced names explicitly instead,
        # resolving each one to the table that owns it.
        resolver = Resolver(scope, schema)
        names = {s.alias_or_name for s in new_selections}

        for name in sorted(parent_selections):
            if name not in names:
                new_selections.append(
                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
                )

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(default_selection(is_agg))

    scope.expression.select(*new_selections, append=False, copy=False)

    if removed:
        scope.clear_cache()
SELECT_ALL =
<object object>
def
pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: schema used to resolve column owners when an unused ``*``
            projection is replaced by explicit columns; normalized with
            ``ensure_schema``
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: the same expression, modified in place
    Raises:
        OptimizeError: if the two branches of a set operation (one without the
            BigQuery-specific kind/side modifiers) project a different number
            of columns.
    """
    schema = ensure_schema(schema)
    # Number of column aliases declared on a derived table, e.g. `AS t(a, b)`, keyed by
    # the scope that produces it; those leading projections are kept positionally.
    source_column_alias_count = {}
    # Map of Scope to all columns being selected by outer queries. Every value is a set.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            set_op = scope.expression
            if not (set_op.kind or set_op.side):
                # Do not optimize this set operation if it's using the BigQuery specific
                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
                left, right = scope.union_scopes
                if len(left.expression.selects) != len(right.expression.selects):
                    scope_sql = scope.expression.sql()
                    raise OptimizeError(
                        f"Invalid set operation due to column mismatch: {scope_sql}."
                    )

                referenced_columns[left] = parent_selections

                if any(select.is_star for select in right.expression.selects):
                    referenced_columns[right] = parent_selections
                elif not any(select.is_star for select in left.expression.selects):
                    if scope.expression.args.get("by_name"):
                        referenced_columns[right] = referenced_columns[left]
                    else:
                        # Columns match positionally: keep the right-hand projection at
                        # every index whose left-hand projection is still needed. Build a
                        # set, like every other entry of referenced_columns.
                        referenced_columns[right] = {
                            right.expression.selects[i].alias_or_name
                            for i, select in enumerate(left.expression.selects)
                            if SELECT_ALL in parent_selections
                            or select.alias_or_name in parent_selections
                        }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    # A pivot can reference any column of its source, so keep them all.
                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression
Rewrite a sqlglot AST to remove unused column projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
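Arguments:
- expression (sqlglot.Expression): expression to optimize
- schema (optional): schema used to resolve column owners when an unused `*` projection is replaced by explicit columns; normalized with `ensure_schema`
- remove_unused_selections (bool): remove selects that are unused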
Returns:
sqlglot.Expression: optimized expression