sqlglot.lineage
from __future__ import annotations

import json
import logging
import typing as t
from dataclasses import dataclass, field

from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
from sqlglot.optimizer.scope import ScopeType

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType
    from collections.abc import Iterator, Mapping, Sequence
    from sqlglot._typing import GraphHTMLArgs
    from typing_extensions import Unpack

logger = logging.getLogger("sqlglot")


@dataclass(frozen=True)
class Node:
    """A single step in a column lineage graph.

    The dataclass is frozen, so fields cannot be rebound, but the `downstream`
    list itself is mutated in place while the graph is being built.
    """

    # Display name of the node (`to_node` uses "<scope_name>.<column>" or the column itself).
    name: str
    # The expression this node stands for (a select item, a table, a placeholder, ...).
    expression: exp.Expr
    # The query (or table) the expression was taken from.
    source: exp.Expr
    # Nodes this node was derived from.
    downstream: list[Node] = field(default_factory=list)
    source_name: str = ""
    reference_node_name: str = ""

    def walk(self) -> Iterator[Node]:
        """Yield this node and every node reachable through `downstream`, each one once.

        Traversal is depth-first and pre-order; children are pushed reversed so they
        are popped in their original order. Nodes are de-duplicated by `id()`.
        """
        visited: set[int] = set()
        queue = [self]
        while queue:
            node = queue.pop()
            node_id = id(node)
            if node_id in visited:
                continue
            visited.add(node_id)
            yield node
            queue.extend(reversed(node.downstream))

    def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
        """Build a `GraphHTML` with one vis.js node per lineage node and one edge per downstream link.

        Args:
            dialect: The dialect used to generate the SQL shown in labels and titles.
            **opts: Forwarded unchanged to `GraphHTML`.
        """
        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b>...</b> inside its source query.
                # NOTE(review): copy=False presumably rewrites `node.source` in place, so the
                # tag may persist after this call -- confirm against `transform`.
                source = node.source.transform(
                    lambda n: (
                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
                    ),
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            # Object identity doubles as the vis.js node id.
            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in node.downstream:
                edges.append({"from": node_id, "to": id(d)})
        return GraphHTML(nodes, edges, **opts)


def lineage(
    column: str | exp.Column,
    sql: str | exp.Expr,
    schema: dict | Schema | None = None,
    sources: Mapping[str, str | exp.Query] | None = None,
    dialect: DialectType = None,
    scope: Scope | None = None,
    trim_selects: bool = True,
    copy: bool = True,
    **kwargs,
) -> Node:
    """Build the lineage graph for a column of a SQL query.

    Args:
        column: The column to build the lineage for.
        sql: The SQL string or expression.
        schema: The schema of tables.
        sources: A mapping of queries which will be used to continue building lineage.
        dialect: The dialect of input SQL.
        scope: A pre-created scope to use instead.
        trim_selects: Whether to clean up selects by trimming to only relevant columns.
        copy: Whether to copy the Expr arguments.
        **kwargs: Qualification optimizer kwargs.

    Returns:
        A lineage node.

    Raises:
        SqlglotError: If no scope can be built, or if the column is not among the
            selects of the scope's expression.
    """

    expression = maybe_parse(sql, copy=copy, dialect=dialect)
    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name

    if sources:
        expression = exp.expand(
            expression,
            {
                k: t.cast(exp.Query, maybe_parse(v, copy=copy, dialect=dialect))
                for k, v in sources.items()
            },
            dialect=dialect,
            copy=copy,
        )

    # Only qualify and build a scope when the caller did not supply one.
    if not scope:
        expression = qualify.qualify(
            expression,
            dialect=dialect,
            schema=schema,
            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
        )

        scope = build_scope(expression)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    selectable = scope.expression
    if not isinstance(selectable, exp.Selectable) or not any(
        select.alias_or_name == column for select in selectable.selects
    ):
        raise SqlglotError(f"Cannot find column '{column}' in query.")

    return to_node(column, scope, dialect, trim_selects=trim_selects, _cache={})


def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: str | None = None,
    upstream: Node | None = None,
    source_name: str | None = None,
    reference_node_name: str | None = None,
    trim_selects: bool = True,
    _cache: dict[tuple, Node] | None = None,
) -> Node:
    """Create the lineage node for `column` in `scope` and recurse into its inputs.

    Args:
        column: The select's alias/name, or its positional index (used for set operations).
        scope: The scope whose expression selects the column.
        dialect: The dialect used when rendering SQL in messages.
        scope_name: When given, the node is named "<scope_name>.<column>".
        upstream: The node the result is appended to as a downstream, if any.
        source_name: Stored on the node and forwarded to recursive calls.
        reference_node_name: Stored on the node.
        trim_selects: Whether a Select source is reduced to the one relevant projection.
        _cache: Memo of already-built nodes, keyed by the call's identifying arguments.

    Returns:
        The node for this column (for set operations, the set-operation node).

    Raises:
        SqlglotError: If an integer `column` is past the end of the selects.
        ValueError: If a named column cannot be located in a set operation.
    """
    cache_key = (column, id(scope), scope_name, source_name, reference_node_name)

    if _cache is not None and cache_key in _cache:
        cached_node = _cache[cache_key]
        if upstream:
            upstream.downstream.append(cached_node)
        return cached_node

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    selectable = t.cast(exp.Selectable, scope.expression)
    if isinstance(column, int):
        if column >= len(selectable.selects):
            raise SqlglotError(
                f"Cannot find column's source with index {column} in query: {selectable.sql(dialect=dialect)}"
            )
        select = selectable.selects[column]
    else:
        select = next(
            (select for select in selectable.selects if select.alias_or_name == column),
            exp.Star() if selectable.is_star else scope.expression,
        )

    if isinstance(scope.expression, exp.Subquery):
        # Delegates to the first subquery scope only: the loop returns on its first iteration.
        for inner_scope in scope.subquery_scopes:
            result = to_node(
                column,
                scope=inner_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
            )
            if _cache is not None:
                _cache[cache_key] = result
            return result
    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(selectable.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # Each branch of the set operation contributes the projection at the same position.
        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
            )

        if _cache is not None:
            _cache[cache_key] = upstream
        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        # "x", SELECT x, y FROM foo
        #   => "x", SELECT x FROM foo
        source: exp.Expr = scope.expression.select(select, append=False)
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    # Index the child subquery scopes by the identity of their expression.
    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, *exp.UNWRAPPED_QUERIES):
        subquery_scope: Scope | None = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
            continue

        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
                _cache=_cache,
            )

    # if the select is a star add all scope sources as downstreams
    if isinstance(select, exp.Star):
        for src in scope.sources.values():
            src_expr = src.expression if isinstance(src, Scope) else src
            node.downstream.append(
                Node(name=select.sql(comments=False), source=src_expr, expression=src_expr)
            )

    # Find all columns that went into creating this one to list their lineage nodes.
    # This is a set, so the order in which columns are visited below is not guaranteed.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))
        derived_tables: Sequence[exp.Expr] = [
            src.expression.parent
            for src in scope.sources.values()
            if isinstance(src, Scope) and src.is_derived_table and src.expression.parent
        ]
    else:
        derived_tables = scope.derived_tables

    # Map derived-table alias -> the second whitespace-separated token of a leading
    # "source: ..." comment on that derived table.
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        col_source: exp.Table | Scope | None = scope.sources.get(table)

        if isinstance(col_source, Scope):
            # NOTE(review): this rebinds the `reference_node_name` parameter, so the pivot
            # branch below forwards whatever an earlier iteration assigned here.
            reference_node_name = None
            if col_source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
                reference_node_name = table
            elif col_source.scope_type == ScopeType.CTE:
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                reference_node_name = selected_node.name if selected_node else None

            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=col_source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                pivot_parent = pivot.parent
                downstream_columns.append(
                    exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "")
                )

            for downstream_column in downstream_columns:
                table = downstream_column.table
                col_source = scope.sources.get(table)
                if isinstance(col_source, Scope):
                    to_node(
                        downstream_column.name,
                        scope=col_source,
                        scope_name=table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(table) or source_name,
                        reference_node_name=reference_node_name,
                        trim_selects=trim_selects,
                        _cache=_cache,
                    )
                else:
                    col_expr = col_source or exp.Placeholder()
                    node.downstream.append(
                        Node(
                            name=downstream_column.sql(comments=False),
                            source=col_expr,
                            expression=col_expr,
                        )
                    )
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            col_expr = col_source or exp.Placeholder()
            node.downstream.append(
                Node(name=c.sql(comments=False), source=col_expr, expression=col_expr)
            )

    if _cache is not None:
        _cache[cache_key] = node

    return node


class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        nodes: dict,
        edges: list,
        imports: bool = True,
        options: Mapping[str, object] | None = None,
    ):
        """Store the graph data and merge `options` over the default vis.js options.

        Args:
            nodes: Mapping whose values are the vis.js node dicts.
            edges: List of vis.js edge dicts.
            imports: Whether the rendered HTML includes the vis.js script/style tags.
            options: Top-level option overrides; the merge is shallow, so a key given
                here replaces the whole default value for that key.
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    def __str__(self):
        """Render the graph as an HTML fragment with an inline vis.js script."""
        # NOTE(review): the JSON is interpolated into an inline <script> as-is; a value
        # containing "</script>" would end the script element early.
        nodes = json.dumps(list(self.nodes.values()))
        edges = json.dumps(self.edges)
        options = json.dumps(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`."""
        return self.__str__()
logger =
<Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class
Node:
@dataclass(frozen=True)
class Node:
    """A single step in a column lineage graph.

    The dataclass is frozen, so fields cannot be rebound, but the `downstream`
    list itself is appended to while the graph is being built.
    """

    # Display name of the node.
    name: str
    # The expression this node stands for.
    expression: exp.Expr
    # The query (or table) the expression was taken from.
    source: exp.Expr
    # Nodes this node was derived from.
    downstream: list[Node] = field(default_factory=list)
    source_name: str = ""
    reference_node_name: str = ""

    def walk(self) -> Iterator[Node]:
        """Yield this node and everything reachable via `downstream`, each node once.

        The traversal is depth-first and pre-order. Nodes are de-duplicated by
        identity, so shared sub-graphs and cycles are visited a single time.
        """
        seen: set[int] = set()
        stack = [self]
        while stack:
            current = stack.pop()
            key = id(current)
            if key not in seen:
                seen.add(key)
                yield current
                # Push children reversed so the first child is popped first.
                stack.extend(current.downstream[::-1])

    def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
        """Build a `GraphHTML` with one vis.js node per lineage node and one edge per downstream link.

        Args:
            dialect: The dialect used to generate the SQL shown in labels and titles.
            **opts: Forwarded unchanged to `GraphHTML`.
        """
        nodes = {}
        edges = []

        for node in self.walk():
            # Object identity doubles as the vis.js node id.
            node_id = id(node)
            expression = node.expression

            if isinstance(expression, exp.Table):
                group = 1
                label = f"FROM {expression.this}"
                title = f"<pre>SELECT {node.name} FROM {expression.this}</pre>"
            else:
                group = 0
                label = expression.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b>...</b> inside its source query.
                # NOTE(review): copy=False presumably rewrites `node.source` in place, so the
                # tag may persist after this call -- confirm against `transform`.
                tagged = node.source.transform(
                    lambda n: (
                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is expression else n
                    ),
                    copy=False,
                )
                highlighted = tagged.sql(pretty=True, dialect=dialect)
                title = f"<pre>{highlighted}</pre>"

            nodes[node_id] = {"id": node_id, "label": label, "title": title, "group": group}
            edges.extend({"from": node_id, "to": id(child)} for child in node.downstream)

        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.core.Expr, source: sqlglot.expressions.core.Expr, downstream: list[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
expression: sqlglot.expressions.core.Expr
source: sqlglot.expressions.core.Expr
downstream: list[Node]
def
to_html( self, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, **opts: typing_extensions.Unpack[sqlglot._typing.GraphHTMLArgs]) -> GraphHTML:
def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
    """Build a `GraphHTML` with one vis.js node per lineage node and one edge per downstream link.

    Args:
        dialect: The dialect used to generate the SQL shown in labels and titles.
        **opts: Forwarded unchanged to `GraphHTML`.
    """
    nodes = {}
    edges = []

    for node in self.walk():
        # Object identity doubles as the vis.js node id.
        node_id = id(node)
        expression = node.expression

        if isinstance(expression, exp.Table):
            group = 1
            label = f"FROM {expression.this}"
            title = f"<pre>SELECT {node.name} FROM {expression.this}</pre>"
        else:
            group = 0
            label = expression.sql(pretty=True, dialect=dialect)
            # Wrap the node's own expression in <b>...</b> inside its source query.
            # NOTE(review): copy=False presumably rewrites `node.source` in place, so the
            # tag may persist after this call -- confirm against `transform`.
            tagged = node.source.transform(
                lambda n: (
                    exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is expression else n
                ),
                copy=False,
            )
            highlighted = tagged.sql(pretty=True, dialect=dialect)
            title = f"<pre>{highlighted}</pre>"

        nodes[node_id] = {"id": node_id, "label": label, "title": title, "group": group}
        edges.extend({"from": node_id, "to": id(child)} for child in node.downstream)

    return GraphHTML(nodes, edges, **opts)
def
lineage( column: str | sqlglot.expressions.core.Column, sql: str | sqlglot.expressions.core.Expr, schema: dict | sqlglot.schema.Schema | None = None, sources: Mapping[str, str | sqlglot.expressions.query.Query] | None = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, scope: sqlglot.optimizer.scope.Scope | None = None, trim_selects: bool = True, copy: bool = True, **kwargs) -> Node:
def lineage(
    column: str | exp.Column,
    sql: str | exp.Expr,
    schema: dict | Schema | None = None,
    sources: Mapping[str, str | exp.Query] | None = None,
    dialect: DialectType = None,
    scope: Scope | None = None,
    trim_selects: bool = True,
    copy: bool = True,
    **kwargs,
) -> Node:
    """Build the lineage graph for a column of a SQL query.

    Args:
        column: The column to build the lineage for.
        sql: The SQL string or expression.
        schema: The schema of tables.
        sources: A mapping of queries which will be used to continue building lineage.
        dialect: The dialect of input SQL.
        scope: A pre-created scope to use instead of qualifying the query and building one.
        trim_selects: Whether to clean up selects by trimming to only relevant columns.
        copy: Whether to copy the Expr arguments.
        **kwargs: Qualification optimizer kwargs.

    Returns:
        The lineage node for the requested column.

    Raises:
        SqlglotError: If no scope can be built, or if the column is not among the
            selects of the scope's expression.
    """
    expression = maybe_parse(sql, copy=copy, dialect=dialect)
    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name

    if sources:
        parsed_sources = {}
        for source_key, source_query in sources.items():
            parsed_sources[source_key] = t.cast(
                exp.Query, maybe_parse(source_query, copy=copy, dialect=dialect)
            )
        expression = exp.expand(expression, parsed_sources, dialect=dialect, copy=copy)

    if not scope:
        # Caller-supplied kwargs win over the two defaults below.
        qualify_options = {"validate_qualify_columns": False, "identify": False, **kwargs}
        expression = qualify.qualify(
            expression,
            dialect=dialect,
            schema=schema,
            **qualify_options,  # type: ignore
        )
        scope = build_scope(expression)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    root = scope.expression
    column_is_selected = isinstance(root, exp.Selectable) and any(
        projection.alias_or_name == column for projection in root.selects
    )
    if not column_is_selected:
        raise SqlglotError(f"Cannot find column '{column}' in query.")

    return to_node(column, scope, dialect, trim_selects=trim_selects, _cache={})
Build the lineage graph for a column of a SQL query.
Arguments:
- column: The column to build the lineage for.
- sql: The SQL string or expression.
- schema: The schema of tables.
- sources: A mapping of queries which will be used to continue building lineage.
- dialect: The dialect of input SQL.
- scope: A pre-created scope to use instead of qualifying the query and building a new scope from it.
- trim_selects: Whether to clean up selects by trimming to only relevant columns.
- copy: Whether to copy the Expr arguments.
- **kwargs: Qualification optimizer kwargs.
Returns:
The lineage node for the requested column.
def
to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType], scope_name: str | None = None, upstream: Node | None = None, source_name: str | None = None, reference_node_name: str | None = None, trim_selects: bool = True, _cache: dict[tuple, Node] | None = None) -> Node:
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: str | None = None,
    upstream: Node | None = None,
    source_name: str | None = None,
    reference_node_name: str | None = None,
    trim_selects: bool = True,
    _cache: dict[tuple, Node] | None = None,
) -> Node:
    """Create the lineage node for `column` in `scope` and recurse into its inputs.

    Args:
        column: The select's alias/name, or its positional index (used for set operations).
        scope: The scope whose expression selects the column.
        dialect: The dialect used when rendering SQL in messages.
        scope_name: When given, the node is named "<scope_name>.<column>".
        upstream: The node the result is appended to as a downstream, if any.
        source_name: Stored on the node and forwarded to recursive calls.
        reference_node_name: Stored on the node; also forwarded unchanged when
            recursing through a pivot.
        trim_selects: Whether a Select source is reduced to the one relevant projection.
        _cache: Memo of already-built nodes, keyed by the call's identifying arguments.

    Returns:
        The node for this column (for set operations, the set-operation node).

    Raises:
        SqlglotError: If an integer `column` is past the end of the selects.
        ValueError: If a named column cannot be located in a set operation.
    """
    cache_key = (column, id(scope), scope_name, source_name, reference_node_name)

    if _cache is not None and cache_key in _cache:
        cached_node = _cache[cache_key]
        if upstream:
            upstream.downstream.append(cached_node)
        return cached_node

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    selectable = t.cast(exp.Selectable, scope.expression)
    if isinstance(column, int):
        if column >= len(selectable.selects):
            raise SqlglotError(
                f"Cannot find column's source with index {column} in query: {selectable.sql(dialect=dialect)}"
            )
        select = selectable.selects[column]
    else:
        select = next(
            (select for select in selectable.selects if select.alias_or_name == column),
            exp.Star() if selectable.is_star else scope.expression,
        )

    if isinstance(scope.expression, exp.Subquery):
        # Delegates to the first subquery scope only: the loop returns on its first iteration.
        for inner_scope in scope.subquery_scopes:
            result = to_node(
                column,
                scope=inner_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
            )
            if _cache is not None:
                _cache[cache_key] = result
            return result
    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(selectable.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # Each branch of the set operation contributes the projection at the same position.
        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
            )

        if _cache is not None:
            _cache[cache_key] = upstream
        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        # "x", SELECT x, y FROM foo
        #   => "x", SELECT x FROM foo
        source: exp.Expr = scope.expression.select(select, append=False)
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    # Index the child subquery scopes by the identity of their expression.
    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, *exp.UNWRAPPED_QUERIES):
        subquery_scope: Scope | None = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning("Unknown subquery scope: %s", subquery.sql(dialect=dialect))
            continue

        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
                _cache=_cache,
            )

    # if the select is a star add all scope sources as downstreams
    if isinstance(select, exp.Star):
        for src in scope.sources.values():
            src_expr = src.expression if isinstance(src, Scope) else src
            node.downstream.append(
                Node(name=select.sql(comments=False), source=src_expr, expression=src_expr)
            )

    # Find all columns that went into creating this one to list their lineage nodes.
    # This is a set, so the order in which columns are visited below is not guaranteed.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))
        derived_tables: Sequence[exp.Expr] = [
            src.expression.parent
            for src in scope.sources.values()
            if isinstance(src, Scope) and src.is_derived_table and src.expression.parent
        ]
    else:
        derived_tables = scope.derived_tables

    # Map derived-table alias -> the second whitespace-separated token of a leading
    # "source: ..." comment on that derived table.
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        col_source: exp.Table | Scope | None = scope.sources.get(table)

        if isinstance(col_source, Scope):
            # Kept in a loop-local name on purpose: rebinding the `reference_node_name`
            # parameter here would leak this column's value into the pivot branch of a
            # later iteration, making the result depend on set iteration order.
            column_reference_node_name: str | None = None
            if col_source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
                column_reference_node_name = table
            elif col_source.scope_type == ScopeType.CTE:
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                column_reference_node_name = selected_node.name if selected_node else None

            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=col_source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=column_reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                pivot_parent = pivot.parent
                downstream_columns.append(
                    exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "")
                )

            for downstream_column in downstream_columns:
                table = downstream_column.table
                col_source = scope.sources.get(table)
                if isinstance(col_source, Scope):
                    # Forwards the caller's `reference_node_name` unchanged.
                    to_node(
                        downstream_column.name,
                        scope=col_source,
                        scope_name=table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(table) or source_name,
                        reference_node_name=reference_node_name,
                        trim_selects=trim_selects,
                        _cache=_cache,
                    )
                else:
                    col_expr = col_source or exp.Placeholder()
                    node.downstream.append(
                        Node(
                            name=downstream_column.sql(comments=False),
                            source=col_expr,
                            expression=col_expr,
                        )
                    )
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            col_expr = col_source or exp.Placeholder()
            node.downstream.append(
                Node(name=c.sql(comments=False), source=col_expr, expression=col_expr)
            )

    if _cache is not None:
        _cache[cache_key] = node

    return node
class
GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        nodes: dict,
        edges: list,
        imports: bool = True,
        options: Mapping[str, object] | None = None,
    ) -> None:
        """Store the graph data and merge `options` over the default vis.js options.

        Args:
            nodes: Mapping whose values are the vis.js node dicts.
            edges: List of vis.js edge dicts.
            imports: Whether the rendered HTML includes the vis.js script/style tags.
            options: Top-level option overrides; the merge is shallow, so a key given
                here replaces the whole default value for that key.
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: object) -> str:
        """Serialize `value` to JSON that is safe inside an inline `<script>` element.

        `json.dumps` leaves `<` untouched, so a string such as "</script>" would end
        the script element early and let the rest be parsed as markup. In JSON, `<`
        can only occur inside string literals, where the escape `\\u003c` denotes the
        same character, so the data seen by the script is unchanged.
        """
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self) -> str:
        """Render the graph as an HTML fragment with an inline vis.js script."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`."""
        return self.__str__()
Node to HTML generator using vis.js.
GraphHTML( nodes: dict, edges: list, imports: bool = True, options: Mapping[str, object] | None = None)
401 def __init__( 402 self, 403 nodes: dict, 404 edges: list, 405 imports: bool = True, 406 options: Mapping[str, object] | None = None, 407 ): 408 self.imports = imports 409 410 self.options = { 411 "height": "500px", 412 "width": "100%", 413 "layout": { 414 "hierarchical": { 415 "enabled": True, 416 "nodeSpacing": 200, 417 "sortMethod": "directed", 418 }, 419 }, 420 "interaction": { 421 "dragNodes": False, 422 "selectable": False, 423 }, 424 "physics": { 425 "enabled": False, 426 }, 427 "edges": { 428 "arrows": "to", 429 }, 430 "nodes": { 431 "font": "20px monaco", 432 "shape": "box", 433 "widthConstraint": { 434 "maximum": 300, 435 }, 436 }, 437 **(options or {}), 438 } 439 440 self.nodes = nodes 441 self.edges = edges