sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    allow_partial_qualification: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        allow_partial_qualification: Whether to allow partial qualification.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    # Default: only fall back to schema inference when no schema was supplied at all.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Some dialects resolve alias references before table columns, so aliases
        # must be expanded before qualification in those cases (or when no schema exists).
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                dialect,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver, dialect)

        if isinstance(scope.expression, exp.Select):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            # External columns are only legal in correlated subqueries or under a PIVOT.
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """Yield the columns an UNPIVOT introduces: the IN-clause name column (if any),
    followed by every column referenced in the UNPIVOT's value expressions."""
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        # Recursive CTEs need their column aliases: they name the columns of both branches.
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """Rewrite JOIN ... USING into explicit ON conditions.

    Returns a mapping of each USING column name to an ordered "set" (a dict whose
    values are all None) of the source names it joins, so later steps can expand
    unqualified references to those columns into COALESCE calls.
    """
    columns = {}

    def _update_source_columns(source_name: str) -> None:
        # Remember the first source that provides each column name.
        for column_name in resolver.get_source_columns(source_name):
            if column_name not in columns:
                columns[column_name] = source_name

    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for source_name in ordered:
        _update_source_columns(source_name)

    for i, join in enumerate(joins):
        source_table = ordered[-1]
        if source_table:
            _update_source_columns(source_table)

        join_table = join.alias_or_name
        ordered.append(join_table)

        using = join.args.get("using")
        if not using:
            continue

        join_columns = resolver.get_source_columns(join_table)
        conditions = []
        using_identifier_count = len(using)

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table

            if i == 0 or using_identifier_count == 1:
                lhs: exp.Expression = exp.column(identifier, table=table)
            else:
                # For later joins, the USING column may exist in several earlier sources,
                # so the left-hand side must coalesce all of them.
                coalesce_columns = [
                    exp.column(identifier, table=t)
                    for t in ordered[:-1]
                    if identifier in resolver.get_source_columns(t)
                ]
                if len(coalesce_columns) > 1:
                    lhs = exp.func("coalesce", *coalesce_columns)
                else:
                    lhs = exp.column(identifier, table=table)

            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce_args = [exp.column(column.name, table=table) for table in tables]
                replacement: exp.Expression = exp.func("coalesce", *coalesce_args)

                if isinstance(column.parent, exp.Select):
                    # Ensure the USING column keeps its name if it's projected
                    replacement = alias(replacement, alias=column.name, copy=False)
                elif isinstance(column.parent, exp.Struct):
                    # Ensure the USING column keeps its name if it's an anonymous STRUCT field
                    replacement = exp.PropertyEQ(
                        this=exp.to_identifier(column.name), expression=replacement
                    )

                scope.replace(column, replacement)

    return column_tables
def _expand_alias_refs(
    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
    """
    Expand references to aliases.

    Example:
        SELECT y.foo AS bar, bar * 2 AS baz FROM y
        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Maps each projection alias to (aliased expression, 1-based projection index).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        is_group_by = isinstance(node, exp.Group)
        if not node or (expand_only_groupby and not is_group_by):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
            #   SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
            #   SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
            if expand_only_groupby and is_group_by and column.parent is not node:
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Substituting an aliased aggregate into another aggregate would nest aggregates;
            # windowed references are exempt.
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))

                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(expression.selects):
        replace_columns(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    parent_scope = scope
    while parent_scope.is_union:
        parent_scope = parent_scope.parent

    # We shouldn't expand aliases if they match the recursive CTE's columns
    if parent_scope.is_cte:
        cte = parent_scope.expression.parent
        if cte.find_ancestor(exp.With).recursive:
            for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
                alias_to_expression.pop(recursive_cte_column.output_name, None)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
    if dialect == "snowflake":
        for join in expression.args.get("joins") or []:
            replace_columns(join)

    scope.clear_cache()
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    """Replace positional references in GROUP BY (e.g. GROUP BY 1) with the projections they point to."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """Expand positional ORDER BY references and qualify columns inside ORDER BY aggregates."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        # With a GROUP BY, ordering must reference the output names of the projections.
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    """Replace integer literals in `expressions` with the projection they index.

    When `alias` is True, the replacement is a reference to the projection's alias;
    otherwise it is a copy of the projection expression itself (constants, explosive
    projections, and BigQuery-ambiguous names are kept positional).
    """
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the 1-indexed projection `node` refers to; raise if out of range."""
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        # Only treat the "table" part as a struct when it doesn't name any visible source.
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            # Validate that already-qualified columns actually exist in their source.
            source_columns = resolver.get_source_columns(column_table)
            if (
                not allow_partial_qualification
                and source_columns
                and column_name not in source_columns
                and "*" not in source_columns
            ):
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)
def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column

    Returns one aliased column per struct field, or an empty list when the star
    cannot be expanded (non-struct column, anonymous/ambiguous fields, or an
    unresolvable path).
    """

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root),
            table=dot_column.args.get("table"),
            fields=t.cast(t.List[exp.Identifier], parts),
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections: t.List[exp.Expression] = []
    # All three maps are keyed by id(table-name string) for the table the star targets.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
    rename_columns: t.Dict[int, t.Dict[str, str]] = {}

    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            # Bare `*`: expands over every selected source.
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
            _add_rename_columns(expression, tables, rename_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                # Qualified `tbl.*`: expands over a single source.
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
                _add_rename_columns(expression.this, tables, rename_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # Can't enumerate this source's columns, so the star must be left as-is.
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()
            renamed_columns = rename_columns.get(table_id, {})
            replaced_columns = replace_columns.get(table_id, {})

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    # USING columns collapse to a single COALESCE projection, emitted once.
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce_args = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
                    )
                else:
                    alias_ = renamed_columns.get(name, name)
                    selection_expr = replaced_columns.get(name) or exp.column(name, table=table)
                    new_selections.append(
                        alias(selection_expr, alias_, copy=False)
                        if alias_ != name
                        else selection_expr
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the `* EXCEPT (...)` column names for each table in `tables`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_rename_columns(
    expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the `* RENAME (old AS new)` mappings for each table in `tables`."""
    rename = expression.args.get("rename")

    if not rename:
        return

    columns = {e.this.name: e.alias for e in rename}

    for table in tables:
        rename_columns[id(table)] = columns
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]]
) -> None:
    """Record the `* REPLACE (expr AS name)` substitutions for each table in `tables`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.alias: e for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            # Unnamed projections get a synthetic positional alias.
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # Outer column aliases (e.g. from a table alias's column list) win.
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; see the corresponding accessors below.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # With exactly one schema-less source, assume unknown columns belong to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

            node_alias = node.args.get("alias")
            if node_alias:
                return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            pseudocolumns = self._get_source_pseudocolumns(name)
            if pseudocolumns:
                columns = list(columns)
                columns.extend(c for c in pseudocolumns if c not in columns)

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_source_pseudocolumns(self, name: str) -> t.Sequence[str]:
        """Return pseudocolumns a source implicitly exposes (dialect-dependent)."""
        if self.schema.dialect == "snowflake" and self.scope.expression.args.get("connect"):
            # When there is a CONNECT BY clause, there is only one table being scanned
            # See: https://docs.snowflake.com/en/sql-reference/constructs/connect-by
            return ["LEVEL"]
        return []

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Cached mapping of every selected/lateral source name to its columns.
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # Columns seen in more than one source stop being unambiguous.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None, allow_partial_qualification: bool = False) -> sqlglot.expressions.Expression:
20def qualify_columns( 21 expression: exp.Expression, 22 schema: t.Dict | Schema, 23 expand_alias_refs: bool = True, 24 expand_stars: bool = True, 25 infer_schema: t.Optional[bool] = None, 26 allow_partial_qualification: bool = False, 27) -> exp.Expression: 28 """ 29 Rewrite sqlglot AST to have fully qualified columns. 30 31 Example: 32 >>> import sqlglot 33 >>> schema = {"tbl": {"col": "INT"}} 34 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 35 >>> qualify_columns(expression, schema).sql() 36 'SELECT tbl.col AS col FROM tbl' 37 38 Args: 39 expression: Expression to qualify. 40 schema: Database schema. 41 expand_alias_refs: Whether to expand references to aliases. 42 expand_stars: Whether to expand star queries. This is a necessary step 43 for most of the optimizer's rules to work; do not set to False unless you 44 know what you're doing! 45 infer_schema: Whether to infer the schema if missing. 46 allow_partial_qualification: Whether to allow partial qualification. 47 48 Returns: 49 The qualified expression. 
50 51 Notes: 52 - Currently only handles a single PIVOT or UNPIVOT operator 53 """ 54 schema = ensure_schema(schema) 55 annotator = TypeAnnotator(schema) 56 infer_schema = schema.empty if infer_schema is None else infer_schema 57 dialect = Dialect.get_or_raise(schema.dialect) 58 pseudocolumns = dialect.PSEUDOCOLUMNS 59 60 for scope in traverse_scope(expression): 61 resolver = Resolver(scope, schema, infer_schema=infer_schema) 62 _pop_table_column_aliases(scope.ctes) 63 _pop_table_column_aliases(scope.derived_tables) 64 using_column_tables = _expand_using(scope, resolver) 65 66 if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs: 67 _expand_alias_refs( 68 scope, 69 resolver, 70 dialect, 71 expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY, 72 ) 73 74 _convert_columns_to_dots(scope, resolver) 75 _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification) 76 77 if not schema.empty and expand_alias_refs: 78 _expand_alias_refs(scope, resolver, dialect) 79 80 if isinstance(scope.expression, exp.Select): 81 if expand_stars: 82 _expand_stars( 83 scope, 84 resolver, 85 using_column_tables, 86 pseudocolumns, 87 annotator, 88 ) 89 qualify_outputs(scope) 90 91 _expand_group_by(scope, dialect) 92 _expand_order_by(scope, resolver) 93 94 if dialect == "bigquery": 95 annotator.annotate_scope(scope) 96 97 return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
- allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
100def validate_qualify_columns(expression: E) -> E: 101 """Raise an `OptimizeError` if any columns aren't qualified""" 102 all_unqualified_columns = [] 103 for scope in traverse_scope(expression): 104 if isinstance(scope.expression, exp.Select): 105 unqualified_columns = scope.unqualified_columns 106 107 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 108 column = scope.external_columns[0] 109 for_table = f" for table: '{column.table}'" if column.table else "" 110 raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") 111 112 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 113 # New columns produced by the UNPIVOT can't be qualified, but there may be columns 114 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 115 # this list here to ensure those in the former category will be excluded. 116 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 117 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 118 119 all_unqualified_columns.extend(unqualified_columns) 120 121 if all_unqualified_columns: 122 raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") 123 124 return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
722def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 723 """Ensure all output columns are aliased""" 724 if isinstance(scope_or_expression, exp.Expression): 725 scope = build_scope(scope_or_expression) 726 if not isinstance(scope, Scope): 727 return 728 else: 729 scope = scope_or_expression 730 731 new_selections = [] 732 for i, (selection, aliased_column) in enumerate( 733 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 734 ): 735 if selection is None: 736 break 737 738 if isinstance(selection, exp.Subquery): 739 if not selection.output_name: 740 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 741 elif not isinstance(selection, exp.Alias) and not selection.is_star: 742 selection = alias( 743 selection, 744 alias=selection.output_name or f"_col_{i}", 745 copy=False, 746 ) 747 if aliased_column: 748 selection.set("alias", exp.to_identifier(aliased_column)) 749 750 new_selections.append(selection) 751 752 if isinstance(scope.expression, exp.Select): 753 scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
756def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 757 """Makes sure all identifiers that need to be quoted are quoted.""" 758 return expression.transform( 759 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 760 ) # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
763def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 764 """ 765 Pushes down the CTE alias columns into the projection, 766 767 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 768 769 Example: 770 >>> import sqlglot 771 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 772 >>> pushdown_cte_alias_columns(expression).sql() 773 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 774 775 Args: 776 expression: Expression to pushdown. 777 778 Returns: 779 The expression with the CTE aliases pushed down into the projection. 780 """ 781 for cte in expression.find_all(exp.CTE): 782 if cte.alias_column_names: 783 new_expressions = [] 784 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 785 if isinstance(projection, exp.Alias): 786 projection.set("alias", _alias) 787 else: 788 projection = alias(projection, alias=_alias) 789 new_expressions.append(projection) 790 cte.this.set("expressions", new_expressions) 791 792 return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") >>> pushdown_cte_alias_columns(expression).sql() 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
795class Resolver: 796 """ 797 Helper for resolving columns. 798 799 This is a class so we can lazily load some things and easily share them across functions. 800 """ 801 802 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 803 self.scope = scope 804 self.schema = schema 805 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 806 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 807 self._all_columns: t.Optional[t.Set[str]] = None 808 self._infer_schema = infer_schema 809 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {} 810 811 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 812 """ 813 Get the table for a column name. 814 815 Args: 816 column_name: The column name to find the table for. 817 Returns: 818 The table name if it can be found/inferred. 819 """ 820 if self._unambiguous_columns is None: 821 self._unambiguous_columns = self._get_unambiguous_columns( 822 self._get_all_source_columns() 823 ) 824 825 table_name = self._unambiguous_columns.get(column_name) 826 827 if not table_name and self._infer_schema: 828 sources_without_schema = tuple( 829 source 830 for source, columns in self._get_all_source_columns().items() 831 if not columns or "*" in columns 832 ) 833 if len(sources_without_schema) == 1: 834 table_name = sources_without_schema[0] 835 836 if table_name not in self.scope.selected_sources: 837 return exp.to_identifier(table_name) 838 839 node, _ = self.scope.selected_sources.get(table_name) 840 841 if isinstance(node, exp.Query): 842 while node and node.alias != table_name: 843 node = node.parent 844 845 node_alias = node.args.get("alias") 846 if node_alias: 847 return exp.to_identifier(node_alias.this) 848 849 return exp.to_identifier(table_name) 850 851 @property 852 def all_columns(self) -> t.Set[str]: 853 """All available columns of all sources in this scope""" 854 if self._all_columns is None: 855 self._all_columns = { 856 column for 
columns in self._get_all_source_columns().values() for column in columns 857 } 858 return self._all_columns 859 860 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 861 """Resolve the source columns for a given source `name`.""" 862 cache_key = (name, only_visible) 863 if cache_key not in self._get_source_columns_cache: 864 if name not in self.scope.sources: 865 raise OptimizeError(f"Unknown table: {name}") 866 867 source = self.scope.sources[name] 868 869 if isinstance(source, exp.Table): 870 columns = self.schema.column_names(source, only_visible) 871 elif isinstance(source, Scope) and isinstance( 872 source.expression, (exp.Values, exp.Unnest) 873 ): 874 columns = source.expression.named_selects 875 876 # in bigquery, unnest structs are automatically scoped as tables, so you can 877 # directly select a struct field in a query. 878 # this handles the case where the unnest is statically defined. 879 if self.schema.dialect == "bigquery": 880 if source.expression.is_type(exp.DataType.Type.STRUCT): 881 for k in source.expression.type.expressions: # type: ignore 882 columns.append(k.name) 883 else: 884 columns = source.expression.named_selects 885 886 node, _ = self.scope.selected_sources.get(name) or (None, None) 887 if isinstance(node, Scope): 888 column_aliases = node.expression.alias_column_names 889 elif isinstance(node, exp.Expression): 890 column_aliases = node.alias_column_names 891 else: 892 column_aliases = [] 893 894 if column_aliases: 895 # If the source's columns are aliased, their aliases shadow the corresponding column names. 896 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 
897 columns = [ 898 alias or name 899 for (name, alias) in itertools.zip_longest(columns, column_aliases) 900 ] 901 902 pseudocolumns = self._get_source_pseudocolumns(name) 903 if pseudocolumns: 904 columns = list(columns) 905 columns.extend(c for c in pseudocolumns if c not in columns) 906 907 self._get_source_columns_cache[cache_key] = columns 908 909 return self._get_source_columns_cache[cache_key] 910 911 def _get_source_pseudocolumns(self, name: str) -> t.Sequence[str]: 912 if self.schema.dialect == "snowflake" and self.scope.expression.args.get("connect"): 913 # When there is a CONNECT BY clause, there is only one table being scanned 914 # See: https://docs.snowflake.com/en/sql-reference/constructs/connect-by 915 return ["LEVEL"] 916 return [] 917 918 def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: 919 if self._source_columns is None: 920 self._source_columns = { 921 source_name: self.get_source_columns(source_name) 922 for source_name, source in itertools.chain( 923 self.scope.selected_sources.items(), self.scope.lateral_sources.items() 924 ) 925 } 926 return self._source_columns 927 928 def _get_unambiguous_columns( 929 self, source_columns: t.Dict[str, t.Sequence[str]] 930 ) -> t.Mapping[str, str]: 931 """ 932 Find all the unambiguous columns in sources. 933 934 Args: 935 source_columns: Mapping of names to source columns. 936 937 Returns: 938 Mapping of column name to source name. 939 """ 940 if not source_columns: 941 return {} 942 943 source_columns_pairs = list(source_columns.items()) 944 945 first_table, first_columns = source_columns_pairs[0] 946 947 if len(source_columns_pairs) == 1: 948 # Performance optimization - avoid copying first_columns if there is only one table. 
949 return SingleValuedMapping(first_columns, first_table) 950 951 unambiguous_columns = {col: first_table for col in first_columns} 952 all_columns = set(unambiguous_columns) 953 954 for table, columns in source_columns_pairs[1:]: 955 unique = set(columns) 956 ambiguous = all_columns.intersection(unique) 957 all_columns.update(columns) 958 959 for column in ambiguous: 960 unambiguous_columns.pop(column, None) 961 for column in unique.difference(ambiguous): 962 unambiguous_columns[column] = table 963 964 return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
802 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 803 self.scope = scope 804 self.schema = schema 805 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 806 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 807 self._all_columns: t.Optional[t.Set[str]] = None 808 self._infer_schema = infer_schema 809 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
811 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 812 """ 813 Get the table for a column name. 814 815 Args: 816 column_name: The column name to find the table for. 817 Returns: 818 The table name if it can be found/inferred. 819 """ 820 if self._unambiguous_columns is None: 821 self._unambiguous_columns = self._get_unambiguous_columns( 822 self._get_all_source_columns() 823 ) 824 825 table_name = self._unambiguous_columns.get(column_name) 826 827 if not table_name and self._infer_schema: 828 sources_without_schema = tuple( 829 source 830 for source, columns in self._get_all_source_columns().items() 831 if not columns or "*" in columns 832 ) 833 if len(sources_without_schema) == 1: 834 table_name = sources_without_schema[0] 835 836 if table_name not in self.scope.selected_sources: 837 return exp.to_identifier(table_name) 838 839 node, _ = self.scope.selected_sources.get(table_name) 840 841 if isinstance(node, exp.Query): 842 while node and node.alias != table_name: 843 node = node.parent 844 845 node_alias = node.args.get("alias") 846 if node_alias: 847 return exp.to_identifier(node_alias.this) 848 849 return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
851 @property 852 def all_columns(self) -> t.Set[str]: 853 """All available columns of all sources in this scope""" 854 if self._all_columns is None: 855 self._all_columns = { 856 column for columns in self._get_all_source_columns().values() for column in columns 857 } 858 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
860 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 861 """Resolve the source columns for a given source `name`.""" 862 cache_key = (name, only_visible) 863 if cache_key not in self._get_source_columns_cache: 864 if name not in self.scope.sources: 865 raise OptimizeError(f"Unknown table: {name}") 866 867 source = self.scope.sources[name] 868 869 if isinstance(source, exp.Table): 870 columns = self.schema.column_names(source, only_visible) 871 elif isinstance(source, Scope) and isinstance( 872 source.expression, (exp.Values, exp.Unnest) 873 ): 874 columns = source.expression.named_selects 875 876 # in bigquery, unnest structs are automatically scoped as tables, so you can 877 # directly select a struct field in a query. 878 # this handles the case where the unnest is statically defined. 879 if self.schema.dialect == "bigquery": 880 if source.expression.is_type(exp.DataType.Type.STRUCT): 881 for k in source.expression.type.expressions: # type: ignore 882 columns.append(k.name) 883 else: 884 columns = source.expression.named_selects 885 886 node, _ = self.scope.selected_sources.get(name) or (None, None) 887 if isinstance(node, Scope): 888 column_aliases = node.expression.alias_column_names 889 elif isinstance(node, exp.Expression): 890 column_aliases = node.alias_column_names 891 else: 892 column_aliases = [] 893 894 if column_aliases: 895 # If the source's columns are aliased, their aliases shadow the corresponding column names. 896 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 897 columns = [ 898 alias or name 899 for (name, alias) in itertools.zip_longest(columns, column_aliases) 900 ] 901 902 pseudocolumns = self._get_source_pseudocolumns(name) 903 if pseudocolumns: 904 columns = list(columns) 905 columns.extend(c for c in pseudocolumns if c not in columns) 906 907 self._get_source_columns_cache[cache_key] = columns 908 909 return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source `name`.