sqlglot.optimizer.qualify_columns
"""Optimizer pass that rewrites a sqlglot AST so column references are qualified with their source."""

from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # With no schema, columns can only be resolved by inference unless the caller opts out.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing don't apply to UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the columns an UNPIVOT produces: the IN-clause name column (if any), then value columns."""
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        # Recursive CTEs need their column aliases (they name the recursion's columns), so skip them.
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING into equivalent ON conditions.

    Returns:
        Mapping of each USING column name to an ordered "set" (dict with None values)
        of the source names it joins, used later for COALESCE rewriting.
    """
    # Maps a column name to the first source (in scope order) that provides it.
    columns: t.Dict[str, str] = {}

    def _update_source_columns(source_name: str) -> None:
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # NOTE: `t` below shadows the `typing as t` import, but only inside this comprehension.
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.func("coalesce", *coalesce_args)

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
    """Replace references to projection aliases (in WHERE/GROUP BY/HAVING/QUALIFY) with the aliased expressions."""
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Maps an output alias to (aliased expression, 1-based projection position).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when expanding the alias would nest one aggregate inside another
            # (unless the reference sits under a window, where that's allowed).
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Expand positional (ordinal) references in the GROUP BY clause."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """Expand positional references in ORDER BY and qualify columns inside its aggregates."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # When grouped, ORDER BY terms matching a projection are replaced by the output alias.
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """Map integer ordinals to the corresponding select expressions (or their aliases when `alias=True`)."""
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                # Constants, exploding expressions and ambiguous names keep the ordinal as-is.
                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the aliased projection at 1-based ordinal `node`, raising on out-of-range positions."""
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # Only treat the "table" part as a struct if it isn't actually a known source
        # (here or, for correlated subqueries, in the parent scope).
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # The three dicts below are keyed by id(table name string) of the source a star targets.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Can't expand a star over a source with unknown columns; bail out of the whole pass.
            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the EXCEPT column names of a star expression, keyed by id(table)."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the RENAME mappings (old name -> new alias) of a star expression, keyed by id(table)."""
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    """Record the REPLACE expressions (alias -> replacement) of a star expression, keyed by id(table)."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # Outer column aliases (e.g. from a CTE's column list) override any inner alias.
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; see the corresponding accessors below.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column comes from it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Lazily build and cache the source-name -> columns mapping for this scope."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in more than one source is no longer unambiguous.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns(
    expression: sqlglot.expressions.Expression,
    schema: Union[Dict, sqlglot.schema.Schema],
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # With no schema, columns can only be resolved by inference unless the caller opts out.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing don't apply to UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            # An external (outer-scope) column is only legal in a correlated
            # subquery or under a PIVOT; otherwise it's unresolvable.
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    # zip_longest pairs each projection with its outer (e.g. CTE-declared) alias;
    # selection is None once projections run out, so we stop there.
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            # Subqueries without an output name get a synthetic table alias.
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            # Non-star, non-aliased projections get an explicit alias, falling
            # back to a positional `_col_{i}` name when no name can be derived.
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        # An outer-declared column name always overrides the derived alias.
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        # Nothing to push down for CTEs declared without column aliases.
        if not cte.alias_column_names:
            continue

        projections = []
        for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
            # Overwrite an existing alias, or wrap the bare projection in one.
            if isinstance(projection, exp.Alias):
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            projections.append(projection)

        cte.this.set("expressions", projections)

    return expression
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; computed on first use.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Memoizes get_source_columns, keyed by (source name, only_visible).
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        # If the column is ambiguous/unknown but exactly one source has no schema
        # (or a star projection), attribute the column to that source.
        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that carries the alias matching table_name.
            while node and node.alias != table_name:
                node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Maps every selected/lateral source name to its resolved columns.
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            # A column seen in a previous table becomes ambiguous and is dropped.
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; computed on first use.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Memoizes get_source_columns, keyed by (source name, only_visible).
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        # If the column is ambiguous/unknown but exactly one source has no schema
        # (or a star projection), attribute the column to that source.
        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that carries the alias matching table_name.
            while node and node.alias != table_name:
                node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)
    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns
    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]