sqlglot.lineage
1from __future__ import annotations 2 3import json 4import logging 5import typing as t 6from dataclasses import dataclass, field 7 8from sqlglot import Schema, exp, maybe_parse 9from sqlglot.errors import SqlglotError 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify 11from sqlglot.optimizer.scope import ScopeType 12 13if t.TYPE_CHECKING: 14 from sqlglot.dialects.dialect import DialectType 15 from collections.abc import Iterator, Mapping, Sequence 16 from sqlglot._typing import GraphHTMLArgs 17 from typing_extensions import Unpack 18 19logger = logging.getLogger("sqlglot") 20 21 22@dataclass(frozen=True) 23class Node: 24 name: str 25 expression: exp.Expr 26 source: exp.Expr 27 downstream: list[Node] = field(default_factory=list) 28 source_name: str = "" 29 reference_node_name: str = "" 30 31 # Caller-injected per-node data, populated via the `on_node` hook on lineage() 32 payload: dict[str, t.Any] = field(default_factory=dict) 33 34 def walk(self) -> Iterator[Node]: 35 visited: set[int] = set() 36 queue = [self] 37 while queue: 38 node = queue.pop() 39 node_id = id(node) 40 if node_id in visited: 41 continue 42 visited.add(node_id) 43 yield node 44 queue.extend(reversed(node.downstream)) 45 46 def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML: 47 nodes = {} 48 edges = [] 49 50 for node in self.walk(): 51 if isinstance(node.expression, exp.Table): 52 label = f"FROM {node.expression.this}" 53 title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>" 54 group = 1 55 else: 56 label = node.expression.sql(pretty=True, dialect=dialect) 57 source = node.source.transform( 58 lambda n: ( 59 exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n 60 ), 61 copy=False, 62 ).sql(pretty=True, dialect=dialect) 63 title = f"<pre>{source}</pre>" 64 group = 0 65 66 node_id = id(node) 67 68 nodes[node_id] = { 69 "id": node_id, 70 "label": label, 71 "title": 
title, 72 "group": group, 73 } 74 75 for d in node.downstream: 76 edges.append({"from": node_id, "to": id(d)}) 77 return GraphHTML(nodes, edges, **opts) 78 79 80@t.overload 81def lineage(column: str | exp.Column, sql: str | exp.Expr, **kwargs: t.Any) -> Node: ... 82 83 84@t.overload 85def lineage(column: None, sql: str | exp.Expr, **kwargs: t.Any) -> dict[str, Node]: ... 86 87 88def lineage( 89 column: str | exp.Column | None, 90 sql: str | exp.Expr, 91 schema: dict | Schema | None = None, 92 sources: Mapping[str, str | exp.Query] | None = None, 93 dialect: DialectType = None, 94 scope: Scope | None = None, 95 trim_selects: bool = True, 96 copy: bool = True, 97 on_node: t.Callable[[Node], None] | None = None, 98 **kwargs, 99) -> Node | dict[str, Node]: 100 """Build the lineage graph for a SQL query. 101 102 If `column` is given, returns the lineage Node for that single output column. 103 If `column` is None, returns a dict mapping every top-level output column name 104 to its lineage Node (with a shared cache so cross-column work is deduplicated). 105 106 Args: 107 column: The column to build the lineage for. Pass None to get all output columns. 108 sql: The SQL string or expression. 109 schema: The schema of tables. 110 sources: A mapping of queries which will be used to continue building lineage. 111 dialect: The dialect of input SQL. 112 scope: A pre-created scope to use instead. 113 trim_selects: Whether to clean up selects by trimming to only relevant columns. 114 copy: Whether to copy the Expr arguments. 115 on_node: Optional callback invoked for every Node created during the walk, 116 after the Node's downstream is populated. Useful for injecting 117 caller-managed data into Node.payload during the walk. 118 **kwargs: Qualification optimizer kwargs. 119 120 Returns: 121 A Node when `column` is provided, or a dict[str, Node] when `column` is None. 
122 """ 123 expression = maybe_parse(sql, copy=copy, dialect=dialect) 124 125 if sources: 126 expression = exp.expand( 127 expression, 128 { 129 k: t.cast(exp.Query, maybe_parse(v, copy=copy, dialect=dialect)) 130 for k, v in sources.items() 131 }, 132 dialect=dialect, 133 copy=copy, 134 ) 135 136 if not scope: 137 expression = qualify.qualify( 138 expression, 139 dialect=dialect, 140 schema=schema, 141 **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore 142 ) 143 scope = build_scope(expression) 144 145 if not scope: 146 raise SqlglotError("Cannot build lineage, sql must be SELECT") 147 148 selectable = scope.expression 149 if not isinstance(selectable, exp.Selectable): 150 raise SqlglotError("Cannot build lineage, sql must be a query") 151 152 cache: dict[tuple, Node] = {} 153 scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] = {} 154 155 if column is not None: 156 column_name = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name 157 if not any(select.alias_or_name == column_name for select in selectable.selects): 158 raise SqlglotError(f"Cannot find column '{column_name}' in query.") 159 160 return to_node( 161 column_name, 162 scope, 163 dialect, 164 trim_selects=trim_selects, 165 _cache=cache, 166 _scope_meta=scope_meta, 167 on_node=on_node, 168 ) 169 170 result: dict[str, Node] = {} 171 for sel in selectable.selects: 172 name = sel.alias_or_name 173 if not name: 174 raise SqlglotError( 175 f"Cannot fetch lineage for unnamed projection: {sel.sql(dialect=dialect)}." 
176 ) 177 178 result[name] = to_node( 179 name, 180 scope, 181 dialect, 182 trim_selects=trim_selects, 183 _cache=cache, 184 _scope_meta=scope_meta, 185 on_node=on_node, 186 ) 187 188 return result 189 190 191def to_node( 192 column: str | int, 193 scope: Scope, 194 dialect: DialectType, 195 scope_name: str | None = None, 196 upstream: Node | None = None, 197 source_name: str | None = None, 198 reference_node_name: str | None = None, 199 trim_selects: bool = True, 200 _cache: dict[tuple, Node] | None = None, 201 _scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] | None = None, 202 on_node: t.Callable[[Node], None] | None = None, 203) -> Node: 204 cache_key = (column, id(scope), scope_name, source_name, reference_node_name) 205 206 if _cache is not None and cache_key in _cache: 207 cached_node = _cache[cache_key] 208 if upstream: 209 upstream.downstream.append(cached_node) 210 return cached_node 211 212 # Find the specific select clause that is the source of the column we want. 213 # This can either be a specific, named select or a generic `*` clause. 214 selectable = t.cast(exp.Selectable, scope.expression) 215 if isinstance(column, int): 216 if column >= len(selectable.selects): 217 raise SqlglotError( 218 f"Cannot find column's source with index {column} in query: {selectable.sql(dialect=dialect)}" 219 ) 220 select = selectable.selects[column] 221 else: 222 # Resolving a column to its select scans selectable.selects on every call; 223 # memoize a per-scope {name: select} map and is_star bit instead. 
224 if _scope_meta is None: 225 select = next( 226 (s for s in selectable.selects if s.alias_or_name == column), 227 exp.Star() if selectable.is_star else scope.expression, 228 ) 229 else: 230 scope_id = id(scope) 231 meta = _scope_meta.get(scope_id) 232 if meta is None: 233 select_by_name: dict[str, exp.Expr] = {} 234 for sel in selectable.selects: 235 select_by_name.setdefault(sel.alias_or_name, sel) 236 meta = (selectable.is_star, select_by_name) 237 _scope_meta[scope_id] = meta 238 is_star, select_by_name = meta 239 select = select_by_name.get(column, exp.Star() if is_star else scope.expression) 240 241 if isinstance(scope.expression, exp.Subquery): 242 for inner_scope in scope.subquery_scopes: 243 result = to_node( 244 column, 245 scope=inner_scope, 246 dialect=dialect, 247 upstream=upstream, 248 source_name=source_name, 249 reference_node_name=reference_node_name, 250 trim_selects=trim_selects, 251 _cache=_cache, 252 _scope_meta=_scope_meta, 253 on_node=on_node, 254 ) 255 # Skip caching a passed-in upstream returned by an inner SetOp: 256 # a sibling call at the same key with that node as its upstream 257 # would otherwise self-loop on the cache hit. 
258 if _cache is not None and result is not upstream: 259 _cache[cache_key] = result 260 return result 261 if isinstance(scope.expression, exp.SetOperation): 262 name = type(scope.expression).__name__.upper() 263 created_setop = upstream is None 264 upstream = upstream or Node(name=name, source=scope.expression, expression=select) 265 266 index = ( 267 column 268 if isinstance(column, int) 269 else next( 270 ( 271 i 272 for i, select in enumerate(selectable.selects) 273 if select.alias_or_name == column or select.is_star 274 ), 275 -1, # mypy will not allow a None here, but a negative index should never be returned 276 ) 277 ) 278 279 if index == -1: 280 raise ValueError(f"Could not find {column} in {scope.expression}") 281 282 for s in scope.union_scopes: 283 to_node( 284 index, 285 scope=s, 286 dialect=dialect, 287 upstream=upstream, 288 source_name=source_name, 289 reference_node_name=reference_node_name, 290 trim_selects=trim_selects, 291 _cache=_cache, 292 _scope_meta=_scope_meta, 293 on_node=on_node, 294 ) 295 296 if _cache is not None and created_setop: 297 _cache[cache_key] = upstream 298 if created_setop and on_node: 299 on_node(upstream) 300 return upstream 301 302 if trim_selects and isinstance(scope.expression, exp.Select): 303 # For better ergonomics in our node labels, replace the full select with 304 # a version that has only the column we care about. 305 # "x", SELECT x, y FROM foo 306 # => "x", SELECT x FROM foo 307 source: exp.Expr = scope.expression.select(select, append=False) 308 else: 309 source = scope.expression 310 311 # Create the node for this step in the lineage chain, and attach it to the previous one. 
312 node = Node( 313 name=f"{scope_name}.{column}" if scope_name else str(column), 314 source=source, 315 expression=select, 316 source_name=source_name or "", 317 reference_node_name=reference_node_name or "", 318 ) 319 320 if upstream: 321 upstream.downstream.append(node) 322 323 subquery_scopes = { 324 id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes 325 } 326 327 for subquery in find_all_in_scope(select, *exp.UNWRAPPED_QUERIES): 328 subquery_scope: Scope | None = subquery_scopes.get(id(subquery)) 329 if not subquery_scope: 330 logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}") 331 continue 332 333 for name in subquery.named_selects: 334 to_node( 335 name, 336 scope=subquery_scope, 337 dialect=dialect, 338 upstream=node, 339 trim_selects=trim_selects, 340 _cache=_cache, 341 _scope_meta=_scope_meta, 342 on_node=on_node, 343 ) 344 345 # if the select is a star add all scope sources as downstreams 346 if isinstance(select, exp.Star): 347 for src in scope.sources.values(): 348 src_expr = src.expression if isinstance(src, Scope) else src 349 star_node = Node(name=select.sql(comments=False), source=src_expr, expression=src_expr) 350 node.downstream.append(star_node) 351 if on_node: 352 on_node(star_node) 353 354 # Find all columns that went into creating this one to list their lineage nodes. 
355 source_columns = set(find_all_in_scope(select, exp.Column)) 356 357 # If the source is a UDTF find columns used in the UDTF to generate the table 358 if isinstance(source, exp.UDTF): 359 source_columns |= set(source.find_all(exp.Column)) 360 derived_tables: Sequence[exp.Expr] = [ 361 src.expression.parent 362 for src in scope.sources.values() 363 if isinstance(src, Scope) and src.is_derived_table and src.expression.parent 364 ] 365 else: 366 derived_tables = scope.derived_tables 367 368 source_names = { 369 dt.alias: dt.comments[0].split()[1] 370 for dt in derived_tables 371 if dt.comments and dt.comments[0].startswith("source: ") 372 } 373 374 pivots = scope.pivots 375 pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None 376 if pivot: 377 # For each aggregation function, the pivot creates a new column for each field in category 378 # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a, 379 # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum' 380 # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs 381 # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest 382 # in the lineage, so lookup the pivot column name by index and map that with the columns used 383 # in the aggregation. 
384 # 385 # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b') 386 pivot_columns = pivot.args["columns"] 387 pivot_aggs_count = len(pivot.expressions) 388 389 pivot_column_mapping = {} 390 for i, agg in enumerate(pivot.expressions): 391 agg_cols = list(agg.find_all(exp.Column)) 392 for col_index in range(i, len(pivot_columns), pivot_aggs_count): 393 pivot_column_mapping[pivot_columns[col_index].name] = agg_cols 394 395 for c in source_columns: 396 table = c.table 397 col_source: exp.Table | Scope | None = scope.sources.get(table) 398 399 if isinstance(col_source, Scope): 400 reference_node_name = None 401 if col_source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names: 402 reference_node_name = table 403 elif col_source.scope_type == ScopeType.CTE: 404 selected_node, _ = scope.selected_sources.get(table, (None, None)) 405 reference_node_name = selected_node.name if selected_node else None 406 407 # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. 408 to_node( 409 c.name, 410 scope=col_source, 411 dialect=dialect, 412 scope_name=table, 413 upstream=node, 414 source_name=source_names.get(table) or source_name, 415 reference_node_name=reference_node_name, 416 trim_selects=trim_selects, 417 _cache=_cache, 418 _scope_meta=_scope_meta, 419 on_node=on_node, 420 ) 421 elif pivot and pivot.alias_or_name == c.table: 422 downstream_columns = [] 423 424 column_name = c.name 425 if any(column_name == pivot_column.name for pivot_column in pivot_columns): 426 downstream_columns.extend(pivot_column_mapping[column_name]) 427 else: 428 # The column is not in the pivot, so it must be an implicit column of the 429 # pivoted source -- adapt column to be from the implicit pivoted source. 
430 pivot_parent = pivot.parent 431 downstream_columns.append( 432 exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "") 433 ) 434 435 for downstream_column in downstream_columns: 436 table = downstream_column.table 437 col_source = scope.sources.get(table) 438 if isinstance(col_source, Scope): 439 to_node( 440 downstream_column.name, 441 scope=col_source, 442 scope_name=table, 443 dialect=dialect, 444 upstream=node, 445 source_name=source_names.get(table) or source_name, 446 reference_node_name=reference_node_name, 447 trim_selects=trim_selects, 448 _cache=_cache, 449 _scope_meta=_scope_meta, 450 on_node=on_node, 451 ) 452 else: 453 col_expr = col_source or exp.Placeholder() 454 pivot_leaf = Node( 455 name=downstream_column.sql(comments=False), 456 source=col_expr, 457 expression=col_expr, 458 ) 459 node.downstream.append(pivot_leaf) 460 if on_node: 461 on_node(pivot_leaf) 462 else: 463 # The source is not a scope and the column is not in any pivot - we've reached the end 464 # of the line. At this point, if a source is not found it means this column's lineage 465 # is unknown. This can happen if the definition of a source used in a query is not 466 # passed into the `sources` map. 467 col_expr = col_source or exp.Placeholder() 468 leaf = Node(name=c.sql(comments=False), source=col_expr, expression=col_expr) 469 node.downstream.append(leaf) 470 if on_node: 471 on_node(leaf) 472 473 if _cache is not None: 474 _cache[cache_key] = node 475 476 if on_node: 477 on_node(node) 478 479 return node 480 481 482class GraphHTML: 483 """Node to HTML generator using vis.js. 
484 485 https://visjs.github.io/vis-network/docs/network/ 486 """ 487 488 def __init__( 489 self, 490 nodes: dict, 491 edges: list, 492 imports: bool = True, 493 options: Mapping[str, object] | None = None, 494 ): 495 self.imports = imports 496 497 self.options = { 498 "height": "500px", 499 "width": "100%", 500 "layout": { 501 "hierarchical": { 502 "enabled": True, 503 "nodeSpacing": 200, 504 "sortMethod": "directed", 505 }, 506 }, 507 "interaction": { 508 "dragNodes": False, 509 "selectable": False, 510 }, 511 "physics": { 512 "enabled": False, 513 }, 514 "edges": { 515 "arrows": "to", 516 }, 517 "nodes": { 518 "font": "20px monaco", 519 "shape": "box", 520 "widthConstraint": { 521 "maximum": 300, 522 }, 523 }, 524 **(options or {}), 525 } 526 527 self.nodes = nodes 528 self.edges = edges 529 530 def __str__(self): 531 nodes = json.dumps(list(self.nodes.values())) 532 edges = json.dumps(self.edges) 533 options = json.dumps(self.options) 534 imports = ( 535 """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script> 536 <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script> 537 <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />""" 538 if self.imports 539 else "" 540 ) 541 542 return f"""<div> 543 <div id="sqlglot-lineage"></div> 544 {imports} 545 <script type="text/javascript"> 546 var nodes = new vis.DataSet({nodes}) 547 nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0]) 548 549 new vis.Network( 550 document.getElementById("sqlglot-lineage"), 551 {{ 552 nodes: nodes, 553 edges: new vis.DataSet({edges}) 554 }}, 555 {options}, 556 ) 557 </script> 558</div>""" 559 560 def _repr_html_(self) -> str: 561 return self.__str__()
logger =
<Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class
Node:
@dataclass(frozen=True)
class Node:
    """A single step in a column lineage graph.

    The dataclass is frozen, but `downstream` and `payload` are mutable
    containers, so edges and caller data can still be added after construction.
    """

    name: str
    expression: exp.Expr
    source: exp.Expr
    downstream: list[Node] = field(default_factory=list)
    source_name: str = ""
    reference_node_name: str = ""

    # Caller-injected per-node data, populated via the `on_node` hook on lineage()
    payload: dict[str, t.Any] = field(default_factory=dict)

    def walk(self) -> Iterator[Node]:
        """Yield this node and every node reachable through `downstream`, once each.

        The traversal is depth-first and follows `downstream` order. Nodes are
        deduplicated by object identity, so shared sub-graphs and cycles are
        only visited a single time.
        """
        seen: set[int] = set()
        stack = [self]
        while stack:
            current = stack.pop()
            key = id(current)
            if key not in seen:
                seen.add(key)
                yield current
                # Push the children back-to-front so the first child is popped next.
                stack.extend(current.downstream[::-1])

    def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
        """Build a `GraphHTML` for the graph reachable from this node.

        Args:
            dialect: The dialect used to render SQL for labels and titles.
            **opts: Forwarded unchanged to `GraphHTML`.

        Returns:
            A `GraphHTML` holding one entry per walked node (keyed by `id(node)`)
            and one edge per `downstream` link.
        """
        nodes = {}
        edges = []

        for node in self.walk():
            node_id = id(node)
            expression = node.expression

            if not isinstance(expression, exp.Table):
                group = 0
                label = expression.sql(pretty=True, dialect=dialect)

                def emphasize(candidate):
                    # Wrap only the node's own expression so it renders in bold.
                    if candidate is expression:
                        return exp.Tag(this=candidate, prefix="<b>", postfix="</b>")
                    return candidate

                tagged = node.source.transform(emphasize, copy=False)
                title = f"<pre>{tagged.sql(pretty=True, dialect=dialect)}</pre>"
            else:
                group = 1
                label = f"FROM {expression.this}"
                title = f"<pre>SELECT {node.name} FROM {expression.this}</pre>"

            nodes[node_id] = dict(id=node_id, label=label, title=title, group=group)
            edges.extend({"from": node_id, "to": id(child)} for child in node.downstream)

        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.core.Expr, source: sqlglot.expressions.core.Expr, downstream: list[Node] = <factory>, source_name: str = '', reference_node_name: str = '', payload: dict[str, typing.Any] = <factory>)
expression: sqlglot.expressions.core.Expr
source: sqlglot.expressions.core.Expr
downstream: list[Node]
def
to_html( self, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, **opts: typing_extensions.Unpack[sqlglot._typing.GraphHTMLArgs]) -> GraphHTML:
def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
    """Build a `GraphHTML` for the graph reachable from this node.

    Args:
        dialect: The dialect used to render SQL for labels and titles.
        **opts: Forwarded unchanged to `GraphHTML`.

    Returns:
        A `GraphHTML` holding one entry per walked node (keyed by `id(node)`)
        and one edge per `downstream` link.
    """
    nodes = {}
    edges = []

    for node in self.walk():
        node_id = id(node)
        expression = node.expression

        if not isinstance(expression, exp.Table):
            group = 0
            label = expression.sql(pretty=True, dialect=dialect)

            def emphasize(candidate):
                # Wrap only the node's own expression so it renders in bold.
                if candidate is expression:
                    return exp.Tag(this=candidate, prefix="<b>", postfix="</b>")
                return candidate

            tagged = node.source.transform(emphasize, copy=False)
            title = f"<pre>{tagged.sql(pretty=True, dialect=dialect)}</pre>"
        else:
            group = 1
            label = f"FROM {expression.this}"
            title = f"<pre>SELECT {node.name} FROM {expression.this}</pre>"

        nodes[node_id] = dict(id=node_id, label=label, title=title, group=group)
        edges.extend({"from": node_id, "to": id(child)} for child in node.downstream)

    return GraphHTML(nodes, edges, **opts)
def
lineage( column: str | sqlglot.expressions.core.Column | None, sql: str | sqlglot.expressions.core.Expr, schema: dict | sqlglot.schema.Schema | None = None, sources: Mapping[str, str | sqlglot.expressions.query.Query] | None = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, scope: sqlglot.optimizer.scope.Scope | None = None, trim_selects: bool = True, copy: bool = True, on_node: Optional[Callable[[Node], NoneType]] = None, **kwargs) -> Node | dict[str, Node]:
def lineage(
    column: str | exp.Column | None,
    sql: str | exp.Expr,
    schema: dict | Schema | None = None,
    sources: Mapping[str, str | exp.Query] | None = None,
    dialect: DialectType = None,
    scope: Scope | None = None,
    trim_selects: bool = True,
    copy: bool = True,
    on_node: t.Callable[[Node], None] | None = None,
    **kwargs,
) -> Node | dict[str, Node]:
    """Build the lineage graph for a SQL query.

    With a `column`, the result is the lineage Node of that one output column.
    With `column=None`, the result maps each top-level output column name to its
    lineage Node; all columns share one cache so repeated work is deduplicated.

    Args:
        column: The output column to trace, or None to trace every output column.
        sql: The SQL string or expression.
        schema: The schema of tables.
        sources: A mapping of queries which will be used to continue building lineage.
        dialect: The dialect of input SQL.
        scope: A pre-created scope to use instead of qualifying `sql` and building one.
        trim_selects: Whether to clean up selects by trimming to only relevant columns.
        copy: Whether to copy the Expr arguments.
        on_node: Optional callback invoked for every Node created during the walk,
            after the Node's downstream is populated. Useful for injecting
            caller-managed data into Node.payload during the walk.
        **kwargs: Qualification optimizer kwargs; they override the defaults set here.

    Returns:
        A Node when `column` is provided, or a dict[str, Node] when `column` is None.

    Raises:
        SqlglotError: If no scope can be built, the scope's expression is not a
            query, the requested column is not among the projections, or (when
            `column` is None) a projection has no name.
    """
    expression = maybe_parse(sql, copy=copy, dialect=dialect)

    if sources:
        parsed_sources = {}
        for source_key, source_sql in sources.items():
            parsed_sources[source_key] = t.cast(
                exp.Query, maybe_parse(source_sql, copy=copy, dialect=dialect)
            )
        expression = exp.expand(expression, parsed_sources, dialect=dialect, copy=copy)

    if not scope:
        # Caller-supplied kwargs win over the lineage defaults.
        qualify_opts = {"validate_qualify_columns": False, "identify": False}
        qualify_opts.update(kwargs)
        expression = qualify.qualify(  # type: ignore
            expression,
            dialect=dialect,
            schema=schema,
            **qualify_opts,
        )
        scope = build_scope(expression)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    selectable = scope.expression
    if not isinstance(selectable, exp.Selectable):
        raise SqlglotError("Cannot build lineage, sql must be a query")

    # Both memo tables are shared by every to_node call made for this query.
    cache: dict[tuple, Node] = {}
    scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] = {}

    def node_for(target: str) -> Node:
        return to_node(
            target,
            scope,
            dialect,
            trim_selects=trim_selects,
            _cache=cache,
            _scope_meta=scope_meta,
            on_node=on_node,
        )

    if column is not None:
        column_name = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
        if all(projection.alias_or_name != column_name for projection in selectable.selects):
            raise SqlglotError(f"Cannot find column '{column_name}' in query.")
        return node_for(column_name)

    result: dict[str, Node] = {}
    for projection in selectable.selects:
        name = projection.alias_or_name
        if name:
            result[name] = node_for(name)
            continue
        raise SqlglotError(
            f"Cannot fetch lineage for unnamed projection: {projection.sql(dialect=dialect)}."
        )

    return result
Build the lineage graph for a SQL query.
If column is given, returns the lineage Node for that single output column.
If column is None, returns a dict mapping every top-level output column name
to its lineage Node (with a shared cache so cross-column work is deduplicated).
Arguments:
- column: The column to build the lineage for. Pass None to get all output columns.
- sql: The SQL string or expression.
- schema: The schema of tables.
- sources: A mapping of queries which will be used to continue building lineage.
- dialect: The dialect of input SQL.
- scope: A pre-created scope to use instead.
- trim_selects: Whether to clean up selects by trimming to only relevant columns.
- copy: Whether to copy the Expr arguments.
- on_node: Optional callback invoked for every Node created during the walk, after the Node's downstream is populated. Useful for injecting caller-managed data into Node.payload during the walk.
- **kwargs: Qualification optimizer kwargs.
Returns:
A Node when `column` is provided, or a dict[str, Node] when `column` is None.
def
to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType], scope_name: str | None = None, upstream: Node | None = None, source_name: str | None = None, reference_node_name: str | None = None, trim_selects: bool = True, _cache: dict[tuple, Node] | None = None, _scope_meta: dict[int, tuple[bool, dict[str, sqlglot.expressions.core.Expr]]] | None = None, on_node: Optional[Callable[[Node], NoneType]] = None) -> Node:
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: str | None = None,
    upstream: Node | None = None,
    source_name: str | None = None,
    reference_node_name: str | None = None,
    trim_selects: bool = True,
    _cache: dict[tuple, Node] | None = None,
    _scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] | None = None,
    on_node: t.Callable[[Node], None] | None = None,
) -> Node:
    """Build the lineage Node for `column` within `scope`, recursing into its sources.

    Args:
        column: The projection to resolve, either by name or by positional index
            into the scope's selects.
        scope: The scope whose expression is searched for the projection.
        dialect: The dialect used when rendering SQL for error and log messages.
        scope_name: When set, the node is named `{scope_name}.{column}`.
        upstream: The node the result is attached to (appended to its `downstream`).
        source_name: Stored on the created node and passed on to recursive calls
            unless a derived table's `source: ` comment supplies another name.
        reference_node_name: Stored on the created node.
        trim_selects: Whether a Select source is reduced to just the resolved projection.
        _cache: Memo of already-built nodes, keyed by
            (column, id(scope), scope_name, source_name, reference_node_name).
        _scope_meta: Memo of per-scope (is_star, {name: select}) lookups, keyed by id(scope).
        on_node: Callback invoked for each node created here once its downstream
            has been populated. It is not invoked again on a cache hit.

    Returns:
        The node for this step. For a Subquery scope this is the node built from
        its first inner scope; for a set operation it is `upstream` when one was
        passed in, otherwise a new node named after the set operation type.

    Raises:
        SqlglotError: If `column` is an index past the end of the scope's selects.
        ValueError: If a named `column` cannot be located in a set operation.
    """
    cache_key = (column, id(scope), scope_name, source_name, reference_node_name)

    # Cache hit: reuse the node but still record the edge from the caller's upstream.
    if _cache is not None and cache_key in _cache:
        cached_node = _cache[cache_key]
        if upstream:
            upstream.downstream.append(cached_node)
        return cached_node

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    selectable = t.cast(exp.Selectable, scope.expression)
    if isinstance(column, int):
        if column >= len(selectable.selects):
            raise SqlglotError(
                f"Cannot find column's source with index {column} in query: {selectable.sql(dialect=dialect)}"
            )
        select = selectable.selects[column]
    else:
        # Resolving a column to its select scans selectable.selects on every call;
        # memoize a per-scope {name: select} map and is_star bit instead.
        if _scope_meta is None:
            # No memo available: linear scan, falling back to a Star (when the
            # selectable is a star) or to the scope's whole expression.
            select = next(
                (s for s in selectable.selects if s.alias_or_name == column),
                exp.Star() if selectable.is_star else scope.expression,
            )
        else:
            scope_id = id(scope)
            meta = _scope_meta.get(scope_id)
            if meta is None:
                select_by_name: dict[str, exp.Expr] = {}
                for sel in selectable.selects:
                    # setdefault keeps the first select for a duplicated name,
                    # matching the first-match behaviour of the linear scan above.
                    select_by_name.setdefault(sel.alias_or_name, sel)
                meta = (selectable.is_star, select_by_name)
                _scope_meta[scope_id] = meta
            is_star, select_by_name = meta
            select = select_by_name.get(column, exp.Star() if is_star else scope.expression)

    if isinstance(scope.expression, exp.Subquery):
        # Delegate to the first inner scope; the return inside the loop means
        # later subquery scopes are never visited.
        for inner_scope in scope.subquery_scopes:
            result = to_node(
                column,
                scope=inner_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )
            # Skip caching a passed-in upstream returned by an inner SetOp:
            # a sibling call at the same key with that node as its upstream
            # would otherwise self-loop on the cache hit.
            if _cache is not None and result is not upstream:
                _cache[cache_key] = result
            return result
    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        # Only a set-op node created here is cached and reported through on_node;
        # a caller-provided upstream is reused as-is.
        created_setop = upstream is None
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        # Each branch of the set operation is resolved by position, so translate
        # a column name into its index (a star projection matches any name).
        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(selectable.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # Every branch attaches its node directly to the shared set-op node.
        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )

        if _cache is not None and created_setop:
            _cache[cache_key] = upstream
        if created_setop and on_node:
            on_node(upstream)
        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        # "x", SELECT x, y FROM foo
        # => "x", SELECT x FROM foo
        source: exp.Expr = scope.expression.select(select, append=False)
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    # Index the child subquery scopes by the identity of their expression so the
    # subqueries found inside `select` can be matched back to a scope.
    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, *exp.UNWRAPPED_QUERIES):
        subquery_scope: Scope | None = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
            continue

        # Every named projection of the subquery becomes a downstream of this node.
        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )

    # if the select is a star add all scope sources as downstreams
    if isinstance(select, exp.Star):
        for src in scope.sources.values():
            src_expr = src.expression if isinstance(src, Scope) else src
            star_node = Node(name=select.sql(comments=False), source=src_expr, expression=src_expr)
            node.downstream.append(star_node)
            if on_node:
                on_node(star_node)

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))
        derived_tables: Sequence[exp.Expr] = [
            src.expression.parent
            for src in scope.sources.values()
            if isinstance(src, Scope) and src.is_derived_table and src.expression.parent
        ]
    else:
        derived_tables = scope.derived_tables

    # A derived table whose first comment starts with "source: " carries a source
    # name; the second whitespace-separated token of that comment is the name.
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    # Pivot handling only applies when the scope has exactly one pivot and it is
    # not an UNPIVOT.
    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            # Aggregation i owns output columns i, i + count, i + 2 * count, ...
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        col_source: exp.Table | Scope | None = scope.sources.get(table)

        if isinstance(col_source, Scope):
            # NOTE(review): this rebinds the `reference_node_name` parameter. The pivot
            # branch below also passes `reference_node_name`, so the value it sees depends
            # on whether this branch already ran for an earlier column of `source_columns`
            # (a set) -- confirm that is intended.
            reference_node_name = None
            if col_source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
                reference_node_name = table
            elif col_source.scope_type == ScopeType.CTE:
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                reference_node_name = selected_node.name if selected_node else None

            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=col_source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                # A generated pivot column: its lineage is the columns used by
                # the aggregation that produced it.
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                pivot_parent = pivot.parent
                downstream_columns.append(
                    exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "")
                )

            for downstream_column in downstream_columns:
                table = downstream_column.table
                col_source = scope.sources.get(table)
                if isinstance(col_source, Scope):
                    to_node(
                        downstream_column.name,
                        scope=col_source,
                        scope_name=table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(table) or source_name,
                        reference_node_name=reference_node_name,
                        trim_selects=trim_selects,
                        _cache=_cache,
                        _scope_meta=_scope_meta,
                        on_node=on_node,
                    )
                else:
                    # Not a scope: emit a leaf, using a Placeholder when the
                    # source is unknown.
                    col_expr = col_source or exp.Placeholder()
                    pivot_leaf = Node(
                        name=downstream_column.sql(comments=False),
                        source=col_expr,
                        expression=col_expr,
                    )
                    node.downstream.append(pivot_leaf)
                    if on_node:
                        on_node(pivot_leaf)
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            col_expr = col_source or exp.Placeholder()
            leaf = Node(name=c.sql(comments=False), source=col_expr, expression=col_expr)
            node.downstream.append(leaf)
            if on_node:
                on_node(leaf)

    # Cache and report the node only now that its downstream list is complete.
    if _cache is not None:
        _cache[cache_key] = node

    if on_node:
        on_node(node)

    return node
class
GraphHTML:
483class GraphHTML: 484 """Node to HTML generator using vis.js. 485 486 https://visjs.github.io/vis-network/docs/network/ 487 """ 488 489 def __init__( 490 self, 491 nodes: dict, 492 edges: list, 493 imports: bool = True, 494 options: Mapping[str, object] | None = None, 495 ): 496 self.imports = imports 497 498 self.options = { 499 "height": "500px", 500 "width": "100%", 501 "layout": { 502 "hierarchical": { 503 "enabled": True, 504 "nodeSpacing": 200, 505 "sortMethod": "directed", 506 }, 507 }, 508 "interaction": { 509 "dragNodes": False, 510 "selectable": False, 511 }, 512 "physics": { 513 "enabled": False, 514 }, 515 "edges": { 516 "arrows": "to", 517 }, 518 "nodes": { 519 "font": "20px monaco", 520 "shape": "box", 521 "widthConstraint": { 522 "maximum": 300, 523 }, 524 }, 525 **(options or {}), 526 } 527 528 self.nodes = nodes 529 self.edges = edges 530 531 def __str__(self): 532 nodes = json.dumps(list(self.nodes.values())) 533 edges = json.dumps(self.edges) 534 options = json.dumps(self.options) 535 imports = ( 536 """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script> 537 <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script> 538 <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />""" 539 if self.imports 540 else "" 541 ) 542 543 return f"""<div> 544 <div id="sqlglot-lineage"></div> 545 {imports} 546 <script type="text/javascript"> 547 var nodes = new vis.DataSet({nodes}) 548 nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0]) 549 550 new vis.Network( 551 document.getElementById("sqlglot-lineage"), 552 {{ 553 nodes: nodes, 554 edges: new vis.DataSet({edges}) 555 }}, 556 {options}, 557 ) 558 </script> 559</div>""" 560 561 def _repr_html_(self) -> str: 562 return self.__str__()
Node to HTML generator using vis.js.
GraphHTML( nodes: dict, edges: list, imports: bool = True, options: Mapping[str, object] | None = None)
489 def __init__( 490 self, 491 nodes: dict, 492 edges: list, 493 imports: bool = True, 494 options: Mapping[str, object] | None = None, 495 ): 496 self.imports = imports 497 498 self.options = { 499 "height": "500px", 500 "width": "100%", 501 "layout": { 502 "hierarchical": { 503 "enabled": True, 504 "nodeSpacing": 200, 505 "sortMethod": "directed", 506 }, 507 }, 508 "interaction": { 509 "dragNodes": False, 510 "selectable": False, 511 }, 512 "physics": { 513 "enabled": False, 514 }, 515 "edges": { 516 "arrows": "to", 517 }, 518 "nodes": { 519 "font": "20px monaco", 520 "shape": "box", 521 "widthConstraint": { 522 "maximum": 300, 523 }, 524 }, 525 **(options or {}), 526 } 527 528 self.nodes = nodes 529 self.edges = edges