sqlglot.optimizer.pushdown_projections
from collections import defaultdict

from sqlglot import alias, exp
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get

# Sentinel value that means an outer query is selecting ALL columns
SELECT_ALL = object()


def default_selection(is_agg: bool) -> exp.Alias:
    """
    Build the placeholder projection used when a SELECT would otherwise be left
    with an empty selection list.

    Args:
        is_agg: whether one of the removed selections contained an aggregate function.
            If so the placeholder is `MAX(1)` instead of the plain literal `1`, so the
            placeholder is itself an aggregate.
    Returns:
        The placeholder expression aliased as `_`.
    """
    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")


def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema; it is normalized with `ensure_schema` and handed to the
            column resolver, which is only consulted when a removed `*` selection has
            to be replaced by the explicit columns the outer query needs
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    Raises:
        OptimizeError: if the two sides of a set operation project a different
            number of columns
    """
    schema = ensure_schema(schema)

    # Map of Scope to the number of column aliases declared on the source that wraps it.
    source_column_alias_count = {}
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # Use `.get` rather than indexing: a scope nobody recorded columns for must be
        # treated as "everything is selected", not as the defaultdict's empty set.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            set_op = scope.expression
            if not (set_op.kind or set_op.side):
                # Do not optimize this set operation if it's using the BigQuery specific
                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
                left, right = scope.union_scopes
                left_selects = left.expression.selects
                right_selects = right.expression.selects

                if len(left_selects) != len(right_selects):
                    scope_sql = scope.expression.sql()
                    raise OptimizeError(
                        f"Invalid set operation due to column mismatch: {scope_sql}."
                    )

                referenced_columns[left] = parent_selections

                if any(select.is_star for select in right_selects):
                    referenced_columns[right] = parent_selections
                elif not any(select.is_star for select in left_selects):
                    if scope.expression.args.get("by_name"):
                        referenced_columns[right] = referenced_columns[left]
                    else:
                        # Columns are matched by position: keep the right-hand projection
                        # that sits at the same index as every left-hand projection the
                        # outer query needs. A set (not a list) keeps the value type
                        # consistent with every other entry of `referenced_columns`.
                        keep_all = SELECT_ALL in parent_selections
                        referenced_columns[right] = {
                            right_select.alias_or_name
                            for left_select, right_select in zip(left_selects, right_selects)
                            if keep_all or left_select.alias_or_name in parent_selections
                        }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    select = seq_get(source.expression.selects, 0)

                    if scope.pivots or isinstance(select, exp.QueryTransform):
                        columns = {SELECT_ALL}
                    else:
                        columns = selects.get(name) or set()

                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression


def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """
    Replace the selection list of `scope.expression` with only the selections that are
    still needed.

    A selection is kept when the outer query selects everything, when its output name
    is in `parent_selections`, when an unqualified column in the ORDER BY refers to it,
    or while `alias_count` is still positive.

    Args:
        scope: scope whose expression is a SELECT; its selection list is replaced
        parent_selections: collection of output names required by the outer query,
            possibly containing the `SELECT_ALL` sentinel
        schema: schema handed to the `Resolver` when a removed `*` has to be expanded
        alias_count: number of column aliases declared on the source wrapping this
            scope; it is decremented once per kept selection
    """
    order = scope.expression.args.get("order")

    # Assume columns without a qualified table are references to output columns
    order_refs = (
        {c.name for c in order.find_all(exp.Column) if not c.table} if order else set()
    )

    new_selections = []
    removed = False
    star = False
    is_agg = False

    select_all = SELECT_ALL in parent_selections

    for selection in scope.expression.selects:
        name = selection.alias_or_name

        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
            new_selections.append(selection)
            alias_count -= 1
        else:
            if selection.is_star:
                star = True
            removed = True

        # Checked for kept and removed selections alike; only consulted if the
        # selection list ends up empty.
        if not is_agg and selection.find(exp.AggFunc):
            is_agg = True

    if star:
        # A `*` was dropped, so the columns the outer query asked for may no longer be
        # projected; add each missing one explicitly, qualified with its source table.
        resolver = Resolver(scope, schema)
        names = {s.alias_or_name for s in new_selections}

        for name in sorted(parent_selections):
            if name not in names:
                new_selections.append(
                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
                )

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(default_selection(is_agg))

    scope.expression.select(*new_selections, append=False, copy=False)

    if removed:
        # The selection list changed, so anything the scope cached may be stale.
        scope.clear_cache()
SELECT_ALL =
<object object>
def
pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: optional schema; it is normalized with `ensure_schema` and passed on to
            the helper that prunes each SELECT's projections
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression (the same object that was passed in)
    Raises:
        OptimizeError: if the two sides of a set operation project a different
            number of columns
    """
    schema = ensure_schema(schema)

    # Map of Scope to the number of column aliases declared on the source that wraps it.
    source_column_alias_count = {}
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # Use `.get` rather than indexing: a scope nobody recorded columns for must be
        # treated as "everything is selected", not as the defaultdict's empty set.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns from SELECT DISTINCT nor UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            set_op = scope.expression
            if not (set_op.kind or set_op.side):
                # Do not optimize this set operation if it's using the BigQuery specific
                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
                left, right = scope.union_scopes
                left_selects = left.expression.selects
                right_selects = right.expression.selects

                if len(left_selects) != len(right_selects):
                    scope_sql = scope.expression.sql()
                    raise OptimizeError(
                        f"Invalid set operation due to column mismatch: {scope_sql}."
                    )

                referenced_columns[left] = parent_selections

                if any(select.is_star for select in right_selects):
                    referenced_columns[right] = parent_selections
                elif not any(select.is_star for select in left_selects):
                    if scope.expression.args.get("by_name"):
                        referenced_columns[right] = referenced_columns[left]
                    else:
                        # Columns are matched by position: keep the right-hand projection
                        # that sits at the same index as every left-hand projection the
                        # outer query needs. A set (not a list) keeps the value type
                        # consistent with every other entry of `referenced_columns`.
                        keep_all = SELECT_ALL in parent_selections
                        referenced_columns[right] = {
                            right_select.alias_or_name
                            for left_select, right_select in zip(left_selects, right_selects)
                            if keep_all or left_select.alias_or_name in parent_selections
                        }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                selects[col.table].add(col.name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    select = seq_get(source.expression.selects, 0)

                    if scope.pivots or isinstance(select, exp.QueryTransform):
                        columns = {SELECT_ALL}
                    else:
                        columns = selects.get(name) or set()

                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression
Rewrite sqlglot AST to remove unused column projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
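- expression (sqlglot.Expression): expression to optimize
- schema: optional schema, normalized with `ensure_schema`; used to resolve the source table of a column when a removed `*` selection must be expanded into explicit columns
- remove_unused_selections (bool): remove selects that are unused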
Returns:
sqlglot.Expression: optimized expression