sqlglot.optimizer.qualify_columns
"""Optimizer rule that rewrites a sqlglot AST so every column reference is qualified."""

from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects resolve projection aliases before table columns, so alias
        # expansion must happen before column qualification in those cases
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the columns produced by an UNPIVOT: the name column (from IN) then the value columns."""
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        # Recursive CTEs need their column aliases to resolve the recursive reference
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite JOIN ... USING (...) clauses into explicit ON conditions.

    Returns:
        Mapping of each USING column name to an ordered "set" (dict with None values)
        of the source names that contribute that column.
    """
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        # Record the first source that provides each column name
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # Chained USING joins: coalesce the column across all earlier sources
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        # Unqualified references to USING columns are replaced by COALESCE over all sources
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.func("coalesce", *coalesce_args)

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
    """Replace references to projection aliases (in WHERE/GROUP BY/HAVING/QUALIFY) with the aliased expressions."""
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Maps each projection alias to its expression and 1-based position
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col)  --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    # Parenthesize the expansion to preserve precedence, then simplify if possible
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Expand positional references (e.g. GROUP BY 1) into the corresponding projections."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """Expand positional ORDER BY references and qualify columns inside ORDER BY aggregates."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # When there is a GROUP BY, ORDER BY must reference output names, not expressions
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """Replace integer literals with the select expression (or its alias) at that 1-based position."""
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    # Keep the positional reference; expanding it would be invalid or ambiguous
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the aliased projection at 1-based position `node`, raising `OptimizeError` if out of range."""
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the EXCEPT column names of a star expression per source-table id."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the RENAME (old name -> new alias) mapping of a star expression per source-table id."""
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    """Record the REPLACE (alias -> replacement expression) mapping of a star expression per source-table id."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; see the corresponding accessors below
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source lacks schema information, assume the column comes from it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Build (once) and return the mapping of every source name to its columns."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns(
    expression: sqlglot.expressions.Expression,
    schema: Union[Dict, sqlglot.schema.Schema],
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> sqlglot.expressions.Expression:
20def qualify_columns( 21 expression: exp.Expression, 22 schema: t.Dict | Schema, 23 expand_alias_refs: bool = True, 24 expand_stars: bool = True, 25 infer_schema: t.Optional[bool] = None, 26 allow_partial_qualification: bool = False, 27) -> exp.Expression: 28 """ 29 Rewrite sqlglot AST to have fully qualified columns. 30 31 Example: 32 >>> import sqlglot 33 >>> schema = {"tbl": {"col": "INT"}} 34 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 35 >>> qualify_columns(expression, schema).sql() 36 'SELECT tbl.col AS col FROM tbl' 37 38 Args: 39 expression: Expression to qualify. 40 schema: Database schema. 41 expand_alias_refs: Whether to expand references to aliases. 42 expand_stars: Whether to expand star queries. This is a necessary step 43 for most of the optimizer's rules to work; do not set to False unless you 44 know what you're doing! 45 infer_schema: Whether to infer the schema if missing. 46 allow_partial_qualification: Whether to allow partial qualification. 47 48 Returns: 49 The qualified expression. 
50 51 Notes: 52 - Currently only handles a single PIVOT or UNPIVOT operator 53 """ 54 schema = ensure_schema(schema) 55 annotator = TypeAnnotator(schema) 56 infer_schema = schema.empty if infer_schema is None else infer_schema 57 dialect = Dialect.get_or_raise(schema.dialect) 58 pseudocolumns = dialect.PSEUDOCOLUMNS 59 60 for scope in traverse_scope(expression): 61 resolver = Resolver(scope, schema, infer_schema=infer_schema) 62 _pop_table_column_aliases(scope.ctes) 63 _pop_table_column_aliases(scope.derived_tables) 64 using_column_tables = _expand_using(scope, resolver) 65 66 if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs: 67 _expand_alias_refs( 68 scope, 69 resolver, 70 expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY, 71 ) 72 73 _convert_columns_to_dots(scope, resolver) 74 _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification) 75 76 if not schema.empty and expand_alias_refs: 77 _expand_alias_refs(scope, resolver) 78 79 if not isinstance(scope.expression, exp.UDTF): 80 if expand_stars: 81 _expand_stars( 82 scope, 83 resolver, 84 using_column_tables, 85 pseudocolumns, 86 annotator, 87 ) 88 qualify_outputs(scope) 89 90 _expand_group_by(scope, dialect) 91 _expand_order_by(scope, resolver) 92 93 if dialect == "bigquery": 94 annotator.annotate_scope(scope) 95 96 return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unresolved: t.List[exp.Column] = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        candidates = scope.unqualified_columns

        # A column referencing an outer scope that is neither correlated nor pivoted
        # cannot be resolved at all — fail immediately with table context if known.
        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            column = scope.external_columns[0]
            for_table = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

        if candidates and scope.pivots and scope.pivots[0].unpivot:
            # Columns synthesized by an UNPIVOT can't be qualified, but columns under
            # its IN clause can and should be; exclude the synthesized ones so only
            # genuine qualification failures remain.
            produced = set(_unpivot_columns(scope.pivots[0]))
            candidates = [c for c in candidates if c not in produced]

        unresolved.extend(candidates)

    if unresolved:
        raise OptimizeError(f"Ambiguous columns: {unresolved}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    selections = []
    paired = itertools.zip_longest(scope.expression.selects, scope.outer_columns)

    for i, (projection, outer_name) in enumerate(paired):
        if projection is None:
            # More outer column names than projections; nothing left to alias.
            break

        if isinstance(projection, exp.Subquery):
            # Scalar subqueries take a TableAlias rather than a column alias.
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(projection, exp.Alias) and not projection.is_star:
            projection = alias(
                projection,
                alias=projection.output_name or f"_col_{i}",
                copy=False,
            )

        # A column name imposed by the enclosing query wins over any derived alias.
        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        selections.append(projection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", selections)
Ensure all output columns are aliased
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)  # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if not cte.alias_column_names:
            continue

        rewritten = []
        for column_name, projection in zip(cte.alias_column_names, cte.this.expressions):
            # Re-alias an existing Alias in place; otherwise wrap the projection.
            if isinstance(projection, exp.Alias):
                projection.set("alias", column_name)
            else:
                projection = alias(projection, alias=column_name)
            rewritten.append(projection)

        cte.this.set("expressions", rewritten)

    return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches, filled in on first use by the methods below.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Memoizes get_source_columns results, keyed by (source name, only_visible).
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has no (or wildcard) columns, assume an
            # unrecognized column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            # NOTE: when table_name is None this returns None, since
            # exp.to_identifier passes None through.
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the ancestor that actually carries the matching alias,
            # then prefer that alias identifier over the bare table name.
            while node and node.alias != table_name:
                node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                # Physical table: defer to the schema for its column names.
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                # Any other derived source: use its projected output names.
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Maps every selected and lateral source name to its resolved columns.
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column that appears in more than one source can no longer be
            # resolved without an explicit table qualifier.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
782 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 783 self.scope = scope 784 self.schema = schema 785 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 786 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 787 self._all_columns: t.Optional[t.Set[str]] = None 788 self._infer_schema = infer_schema 789 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
791 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 792 """ 793 Get the table for a column name. 794 795 Args: 796 column_name: The column name to find the table for. 797 Returns: 798 The table name if it can be found/inferred. 799 """ 800 if self._unambiguous_columns is None: 801 self._unambiguous_columns = self._get_unambiguous_columns( 802 self._get_all_source_columns() 803 ) 804 805 table_name = self._unambiguous_columns.get(column_name) 806 807 if not table_name and self._infer_schema: 808 sources_without_schema = tuple( 809 source 810 for source, columns in self._get_all_source_columns().items() 811 if not columns or "*" in columns 812 ) 813 if len(sources_without_schema) == 1: 814 table_name = sources_without_schema[0] 815 816 if table_name not in self.scope.selected_sources: 817 return exp.to_identifier(table_name) 818 819 node, _ = self.scope.selected_sources.get(table_name) 820 821 if isinstance(node, exp.Query): 822 while node and node.alias != table_name: 823 node = node.parent 824 825 node_alias = node.args.get("alias") 826 if node_alias: 827 return exp.to_identifier(node_alias.this) 828 829 return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
831 @property 832 def all_columns(self) -> t.Set[str]: 833 """All available columns of all sources in this scope""" 834 if self._all_columns is None: 835 self._all_columns = { 836 column for columns in self._get_all_source_columns().values() for column in columns 837 } 838 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
840 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 841 """Resolve the source columns for a given source `name`.""" 842 cache_key = (name, only_visible) 843 if cache_key not in self._get_source_columns_cache: 844 if name not in self.scope.sources: 845 raise OptimizeError(f"Unknown table: {name}") 846 847 source = self.scope.sources[name] 848 849 if isinstance(source, exp.Table): 850 columns = self.schema.column_names(source, only_visible) 851 elif isinstance(source, Scope) and isinstance( 852 source.expression, (exp.Values, exp.Unnest) 853 ): 854 columns = source.expression.named_selects 855 856 # in bigquery, unnest structs are automatically scoped as tables, so you can 857 # directly select a struct field in a query. 858 # this handles the case where the unnest is statically defined. 859 if self.schema.dialect == "bigquery": 860 if source.expression.is_type(exp.DataType.Type.STRUCT): 861 for k in source.expression.type.expressions: # type: ignore 862 columns.append(k.name) 863 else: 864 columns = source.expression.named_selects 865 866 node, _ = self.scope.selected_sources.get(name) or (None, None) 867 if isinstance(node, Scope): 868 column_aliases = node.expression.alias_column_names 869 elif isinstance(node, exp.Expression): 870 column_aliases = node.alias_column_names 871 else: 872 column_aliases = [] 873 874 if column_aliases: 875 # If the source's columns are aliased, their aliases shadow the corresponding column names. 876 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 877 columns = [ 878 alias or name 879 for (name, alias) in itertools.zip_longest(columns, column_aliases) 880 ] 881 882 self._get_source_columns_cache[cache_key] = columns 883 884 return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source `name`.