sqlglot.optimizer.pushdown_projections
from collections import defaultdict

from sqlglot import alias, exp
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema

# Sentinel value that means an outer query is selecting ALL columns
SELECT_ALL = object()


def default_selection(is_agg: bool) -> exp.Alias:
    """
    Build the placeholder selection used when a selection list ends up empty.

    Args:
        is_agg: whether one of the removed selections contained an aggregate function.
    Returns:
        `MAX(1) AS _` when `is_agg` is true, otherwise `1 AS _`.
    """
    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")


def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema; it is normalized with `ensure_schema` and handed to
            the column resolver when a removed star has to be expanded
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the object that was passed in)
    """
    schema = ensure_schema(schema)

    # Map of Scope to the number of column aliases (`alias_column_names`) declared
    # on the node through which that scope is selected from.
    source_column_alias_count = {}

    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # A scope nobody has registered columns for is treated as fully selected.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.expression.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.expression.selects):
                if scope.expression.args.get("by_name"):
                    referenced_columns[right] = referenced_columns[left]
                else:
                    # Columns are matched by position: keep the right-hand selection
                    # at every index whose left-hand selection is needed. Stored as a
                    # set so that every value in `referenced_columns` has the same
                    # type (the map is mutated with `set.update` further down).
                    referenced_columns[right] = {
                        right.expression.selects[i].alias_or_name
                        for i, select in enumerate(left.expression.selects)
                        if SELECT_ALL in parent_selections
                        or select.alias_or_name in parent_selections
                    }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                table_name = col.table
                col_name = col.name
                selects[table_name].add(col_name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression


def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """
    Drop the selections of `scope.expression` that no outer query needs.

    The select list of `scope.expression` is replaced in place.

    Args:
        scope: scope whose SELECT expression is rewritten.
        parent_selections: names required by outer queries; may contain `SELECT_ALL`.
        schema: schema handed to `Resolver` when a removed star must be expanded.
        alias_count: number of leading selections that are always kept.
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    new_selections = []
    removed = False
    star = False
    is_agg = False

    select_all = SELECT_ALL in parent_selections

    for selection in scope.expression.selects:
        name = selection.alias_or_name

        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
            new_selections.append(selection)
            alias_count -= 1
        else:
            if selection.is_star:
                star = True
            removed = True

        if not is_agg and selection.find(exp.AggFunc):
            is_agg = True

    if star:
        # A star was dropped: re-add each required column explicitly, qualified
        # with the table the resolver finds for it.
        resolver = Resolver(scope, schema)
        names = {s.alias_or_name for s in new_selections}

        for name in sorted(parent_selections):
            if name not in names:
                new_selections.append(
                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
                )

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(default_selection(is_agg))

    scope.expression.select(*new_selections, append=False, copy=False)

    if removed:
        scope.clear_cache()
SELECT_ALL =
<object object>
def
pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema; it is normalized with `ensure_schema` and forwarded
            to the unused-selection removal step
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the object that was passed in)
    """
    schema = ensure_schema(schema)

    # Map of Scope to the number of column aliases (`alias_column_names`) declared
    # on the node through which that scope is selected from.
    source_column_alias_count = {}

    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # A scope nobody has registered columns for is treated as fully selected.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right.expression.selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left.expression.selects):
                if scope.expression.args.get("by_name"):
                    referenced_columns[right] = referenced_columns[left]
                else:
                    # Columns are matched by position: keep the right-hand selection
                    # at every index whose left-hand selection is needed. Stored as a
                    # set so that every value in `referenced_columns` has the same
                    # type (the map is mutated with `set.update` further down).
                    referenced_columns[right] = {
                        right.expression.selects[i].alias_or_name
                        for i, select in enumerate(left.expression.selects)
                        if SELECT_ALL in parent_selections
                        or select.alias_or_name in parent_selections
                    }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                table_name = col.table
                col_name = col.name
                selects[table_name].add(col_name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression
Rewrite sqlglot AST to remove unused column projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
- expression (sqlglot.Expression): expression to optimize
- remove_unused_selections (bool): remove selects that are unused
Returns:
sqlglot.Expression: optimized expression