sqlglot.lineage
"""Column-level lineage for SQL queries, plus an HTML (vis.js) renderer for the resulting graph."""

from __future__ import annotations

import html
import json
import logging
import typing as t
import uuid
from dataclasses import dataclass, field

from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
from sqlglot.optimizer.scope import ScopeType

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

logger = logging.getLogger("sqlglot")

# Markers wrapped around the highlighted expression while a node's SQL is rendered for its HTML
# tooltip. They pass through html.escape untouched and are swapped for <b>/</b> afterwards, so the
# SQL text itself is escaped while the highlight remains real markup.
_HIGHLIGHT_START = "\x02"
_HIGHLIGHT_END = "\x03"


@dataclass(frozen=True)
class Node:
    """One step in the lineage graph of a column.

    Fields:
        name: Label of the step, e.g. ``"x"`` or ``"scope_name.x"``.
        expression: The expression producing the column at this step (a projection, a star,
            or, for leaf nodes, the source itself).
        source: The query or table the expression belongs to.
        downstream: The nodes this step depends on.
        source_name: Name taken from a ``source: <name>`` comment on a derived table, if any.
        reference_node_name: Name of the derived table / CTE the column was read through, if any.
    """

    name: str
    expression: exp.Expression
    source: exp.Expression
    # The dataclass is frozen, but this list is still appended to while the graph is built.
    downstream: t.List[Node] = field(default_factory=list)
    source_name: str = ""
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every downstream node in depth-first pre-order."""
        yield self

        for d in self.downstream:
            yield from d.walk()

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Render the graph rooted at this node as an interactive vis.js network.

        The lineage graph itself is not modified.

        Args:
            dialect: Dialect used to generate the SQL shown in labels and tooltips.
            **opts: Keyword arguments forwarded to `GraphHTML` (e.g. ``imports``, ``options``).

        Returns:
            A `GraphHTML` whose string form is a self-contained HTML snippet.
        """
        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                title = html.escape(f"SELECT {node.name} FROM {node.expression.this}")
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)
                title = _highlighted_sql(node.source, node.expression, dialect)
                group = 0

            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                # The browser parses the title as HTML, so it must only contain escaped SQL.
                "title": f"<pre>{title}</pre>",
                "group": group,
            }

            edges.extend({"from": node_id, "to": id(d)} for d in node.downstream)

        return GraphHTML(nodes, edges, **opts)


def _highlighted_sql(
    source: exp.Expression, target: exp.Expression, dialect: DialectType
) -> str:
    """Return HTML-escaped pretty SQL for `source`, with `target` (if present) wrapped in <b>.

    The highlight is applied to a copy, so `source`, which may be part of the caller's
    parsed query, is never modified.
    """
    copied = source.copy()
    # Relies on Expression.copy() keeping child order: walking both trees in lockstep pairs
    # every original expression with its duplicate.
    counterpart = next(
        (duplicate for original, duplicate in zip(source.walk(), copied.walk()) if original is target),
        None,
    )
    if counterpart is not None:
        copied = copied.transform(
            lambda n: (
                exp.Tag(this=n, prefix=_HIGHLIGHT_START, postfix=_HIGHLIGHT_END)
                if n is counterpart
                else n
            ),
            copy=False,
        )

    escaped = html.escape(copied.sql(pretty=True, dialect=dialect))
    return escaped.replace(_HIGHLIGHT_START, "<b>").replace(_HIGHLIGHT_END, "</b>")


def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Mapping[str, str | exp.Query]] = None,
    dialect: DialectType = None,
    scope: t.Optional[Scope] = None,
    trim_selects: bool = True,
    copy: bool = True,
    **kwargs,
) -> Node:
    """Build the lineage graph for a column of a SQL query.

    Args:
        column: The column to build the lineage for.
        sql: The SQL string or expression.
        schema: The schema of tables.
        sources: A mapping of queries which will be used to continue building lineage.
        dialect: The dialect of input SQL.
        scope: A pre-created scope to use instead.
        trim_selects: Whether to clean up selects by trimming to only relevant columns.
        copy: Whether to copy the Expression arguments (`sql`, `column` and the `sources` queries).
        **kwargs: Qualification optimizer kwargs, forwarded to `qualify.qualify`.

    Returns:
        A lineage node.

    Raises:
        SqlglotError: If no scope can be built from the SQL, or `column` is not one of the
            query's projections.
    """

    expression = maybe_parse(sql, copy=copy, dialect=dialect)

    # normalize_identifiers may rewrite identifiers in place; honor `copy` for a Column
    # argument as well so the caller's expression is left untouched.
    if copy and isinstance(column, exp.Expression):
        column = column.copy()
    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name

    if sources:
        expression = exp.expand(
            expression,
            {
                k: t.cast(exp.Query, maybe_parse(v, copy=copy, dialect=dialect))
                for k, v in sources.items()
            },
            dialect=dialect,
            copy=copy,
        )

    if not scope:
        expression = qualify.qualify(
            expression,
            dialect=dialect,
            schema=schema,
            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
        )

        scope = build_scope(expression)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    if not any(select.alias_or_name == column for select in scope.expression.selects):
        raise SqlglotError(f"Cannot find column '{column}' in query.")

    return to_node(column, scope, dialect, trim_selects=trim_selects)


def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
    trim_selects: bool = True,
) -> Node:
    """Create the lineage node for `column` in `scope` and recurse into what it depends on.

    Args:
        column: The projection to trace, by name or by position (positions are used for the
            branches of a set operation).
        scope: The scope whose query produces the column.
        dialect: Dialect used when rendering SQL for log messages.
        scope_name: Alias under which the outer query reads this scope; prefixes the node name.
        upstream: The node depending on this one; the new node is appended to its `downstream`.
        source_name: Source name inherited from an enclosing ``source:``-commented derived table.
        reference_node_name: Name of the derived table / CTE this column is read through.
        trim_selects: Whether to trim each node's source to only the traced projection.

    Returns:
        The node for `column`. For set operations this is `upstream`, or a new node standing
        for the whole operation when no upstream was given.

    Raises:
        ValueError: If a set operation has no projection matching `column`.
    """
    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (select for select in scope.expression.selects if select.alias_or_name == column),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    if isinstance(scope.expression, exp.Subquery):
        # Only the first subquery scope is followed.
        for source in scope.subquery_scopes:
            return to_node(
                column,
                scope=source,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )
    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, projection in enumerate(scope.expression.selects)
                    if projection.alias_or_name == column or projection.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # Each branch of the set operation contributes the projection at the same position.
        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )

        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #   => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning("Unknown subquery scope: %s", subquery.sql(dialect=dialect))
            continue

        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
            )

    # If the select is a star, add all scope sources as downstreams. A dedicated loop variable
    # keeps `source` bound to this node's own source for the UDTF check below.
    if isinstance(select, exp.Star):
        for star_source in scope.sources.values():
            if isinstance(star_source, Scope):
                star_source = star_source.expression
            node.downstream.append(
                Node(name=select.sql(comments=False), source=star_source, expression=star_source)
            )

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))
        derived_tables = [
            scope_source.expression.parent
            for scope_source in scope.sources.values()
            if isinstance(scope_source, Scope) and scope_source.is_derived_table
        ]
    else:
        derived_tables = scope.derived_tables

    source_names = _source_names_from_comments(derived_tables)

    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        column_source = scope.sources.get(table)

        if isinstance(column_source, Scope):
            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=column_source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=_reference_node_name(scope, column_source, table, source_names),
                trim_selects=trim_selects,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                downstream_columns.append(exp.column(c.this, table=pivot.parent.alias_or_name))

            for downstream_column in downstream_columns:
                downstream_table = downstream_column.table
                downstream_source = scope.sources.get(downstream_table)
                if isinstance(downstream_source, Scope):
                    to_node(
                        downstream_column.name,
                        scope=downstream_source,
                        scope_name=downstream_table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(downstream_table) or source_name,
                        # Computed per source, not carried over from an earlier loop iteration.
                        reference_node_name=_reference_node_name(
                            scope, downstream_source, downstream_table, source_names
                        ),
                        trim_selects=trim_selects,
                    )
                else:
                    leaf = downstream_source or exp.Placeholder()
                    node.downstream.append(
                        Node(
                            name=downstream_column.sql(comments=False),
                            source=leaf,
                            expression=leaf,
                        )
                    )
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            leaf = column_source or exp.Placeholder()
            node.downstream.append(Node(name=c.sql(comments=False), source=leaf, expression=leaf))

    return node


def _source_names_from_comments(derived_tables: t.Iterable[exp.Expression]) -> t.Dict[str, str]:
    """Map derived-table aliases to the name in a leading ``source: <name>`` comment."""
    names: t.Dict[str, str] = {}
    for derived_table in derived_tables:
        comments = derived_table.comments
        if not comments or not comments[0].startswith("source: "):
            continue
        parts = comments[0].split()
        # A bare "source: " comment carries no name; skip it instead of raising IndexError.
        if len(parts) > 1:
            names[derived_table.alias] = parts[1]
    return names


def _reference_node_name(
    scope: Scope, source: Scope, table: str, source_names: t.Mapping[str, str]
) -> t.Optional[str]:
    """Return the derived-table / CTE name that columns of `table` are read through, if any."""
    if source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
        return table
    if source.scope_type == ScopeType.CTE:
        selected_node, _ = scope.selected_sources.get(table, (None, None))
        return selected_node.name if selected_node else None
    return None


class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        nodes: t.Dict,
        edges: t.List,
        imports: bool = True,
        options: t.Optional[t.Dict] = None,
        element_id: t.Optional[str] = None,
    ):
        """
        Args:
            nodes: vis.js node definitions keyed by node id (only the values are rendered).
            edges: vis.js edge definitions.
            imports: Whether to include the vis.js <script>/<link> tags in the output.
            options: vis.js network options; top-level keys replace the defaults below.
            element_id: DOM id of the graph container. Defaults to a unique id, so several graphs
                rendered into one page (e.g. notebook outputs) do not draw into each other.
        """
        self.imports = imports
        self.element_id = element_id or f"sqlglot-lineage-{uuid.uuid4().hex}"

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to embed inside an inline <script> element."""
        # "<" only occurs inside JSON strings; escaping it prevents data such as "</script>"
        # from terminating the element early. "\u003c" parses back to "<".
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self):
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        container_id = self._script_json(self.element_id)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="{html.escape(self.element_id)}"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById({container_id}),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        return self.__str__()
logger =
<Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class
Node:
@dataclass(frozen=True)
class Node:
    """One step in the lineage graph of a column.

    Fields:
        name: Label of the step, e.g. ``"x"`` or ``"scope_name.x"``.
        expression: The expression producing the column at this step (a projection, a star,
            or, for leaf nodes, the source itself).
        source: The query or table the expression belongs to.
        downstream: The nodes this step depends on.
        source_name: Name taken from a ``source: <name>`` comment on a derived table, if any.
        reference_node_name: Name of the derived table / CTE the column was read through, if any.
    """

    name: str
    expression: exp.Expression
    source: exp.Expression
    # The dataclass is frozen, but this list is still appended to while the graph is built.
    downstream: t.List[Node] = field(default_factory=list)
    source_name: str = ""
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every downstream node in depth-first pre-order."""
        yield self

        for d in self.downstream:
            yield from d.walk()

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Render the graph rooted at this node as an interactive vis.js network.

        The lineage graph itself is not modified.

        Args:
            dialect: Dialect used to generate the SQL shown in labels and tooltips.
            **opts: Keyword arguments forwarded to `GraphHTML`.

        Returns:
            A `GraphHTML` whose string form is a self-contained HTML snippet.
        """
        import html

        # Markers around the highlighted expression: they survive html.escape and are then
        # replaced by real <b>/</b> tags, so only the SQL text is escaped.
        start, end = "\x02", "\x03"

        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                title = html.escape(f"SELECT {node.name} FROM {node.expression.this}")
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)
                # Highlight on a copy so node.source (possibly part of the caller's query) is not
                # modified. Relies on Expression.copy() keeping child order, so walking both trees
                # in lockstep pairs each original expression with its duplicate.
                highlighted = node.source.copy()
                target = next(
                    (
                        duplicate
                        for original, duplicate in zip(node.source.walk(), highlighted.walk())
                        if original is node.expression
                    ),
                    None,
                )
                if target is not None:
                    highlighted = highlighted.transform(
                        lambda n: exp.Tag(this=n, prefix=start, postfix=end) if n is target else n,
                        copy=False,
                    )
                title = (
                    html.escape(highlighted.sql(pretty=True, dialect=dialect))
                    .replace(start, "<b>")
                    .replace(end, "</b>")
                )
                group = 0

            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                # The browser parses the title as HTML, so it must only contain escaped SQL.
                "title": f"<pre>{title}</pre>",
                "group": group,
            }

            edges.extend({"from": node_id, "to": id(d)} for d in node.downstream)

        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
expression: sqlglot.expressions.Expression
source: sqlglot.expressions.Expression
downstream: List[Node]
def
to_html( self, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, **opts) -> GraphHTML:
def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
    """Render the graph rooted at this node as an interactive vis.js network.

    The lineage graph itself is not modified.

    Args:
        dialect: Dialect used to generate the SQL shown in labels and tooltips.
        **opts: Keyword arguments forwarded to `GraphHTML`.

    Returns:
        A `GraphHTML` whose string form is a self-contained HTML snippet.
    """
    import html

    # Markers around the highlighted expression: they survive html.escape and are then
    # replaced by real <b>/</b> tags, so only the SQL text is escaped.
    start, end = "\x02", "\x03"

    nodes = {}
    edges = []

    for node in self.walk():
        if isinstance(node.expression, exp.Table):
            label = f"FROM {node.expression.this}"
            title = html.escape(f"SELECT {node.name} FROM {node.expression.this}")
            group = 1
        else:
            label = node.expression.sql(pretty=True, dialect=dialect)
            # Highlight on a copy so node.source (possibly part of the caller's query) is not
            # modified. Relies on Expression.copy() keeping child order, so walking both trees
            # in lockstep pairs each original expression with its duplicate.
            highlighted = node.source.copy()
            target = next(
                (
                    duplicate
                    for original, duplicate in zip(node.source.walk(), highlighted.walk())
                    if original is node.expression
                ),
                None,
            )
            if target is not None:
                highlighted = highlighted.transform(
                    lambda n: exp.Tag(this=n, prefix=start, postfix=end) if n is target else n,
                    copy=False,
                )
            title = (
                html.escape(highlighted.sql(pretty=True, dialect=dialect))
                .replace(start, "<b>")
                .replace(end, "</b>")
            )
            group = 0

        node_id = id(node)

        nodes[node_id] = {
            "id": node_id,
            "label": label,
            # The browser parses the title as HTML, so it must only contain escaped SQL.
            "title": f"<pre>{title}</pre>",
            "group": group,
        }

        edges.extend({"from": node_id, "to": id(d)} for d in node.downstream)

    return GraphHTML(nodes, edges, **opts)
def
lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Mapping[str, str | sqlglot.expressions.Query]] = None, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, scope: Optional[sqlglot.optimizer.scope.Scope] = None, trim_selects: bool = True, copy: bool = True, **kwargs) -> Node:
def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Mapping[str, str | exp.Query]] = None,
    dialect: DialectType = None,
    scope: t.Optional[Scope] = None,
    trim_selects: bool = True,
    copy: bool = True,
    **kwargs,
) -> Node:
    """Build the lineage graph for a column of a SQL query.

    Args:
        column: The column to build the lineage for.
        sql: The SQL string or expression.
        schema: The schema of tables.
        sources: A mapping of queries which will be used to continue building lineage.
        dialect: The dialect of input SQL.
        scope: A pre-created scope to use instead.
        trim_selects: Whether to clean up selects by trimming to only relevant columns.
        copy: Whether to copy the Expression arguments (`sql`, `column` and the `sources` queries).
        **kwargs: Qualification optimizer kwargs, forwarded to `qualify.qualify`.

    Returns:
        A lineage node.

    Raises:
        SqlglotError: If no scope can be built from the SQL, or `column` is not one of the
            query's projections.
    """

    expression = maybe_parse(sql, copy=copy, dialect=dialect)

    # normalize_identifiers may rewrite identifiers in place; honor `copy` for a Column
    # argument as well so the caller's expression is left untouched.
    if copy and isinstance(column, exp.Expression):
        column = column.copy()
    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name

    if sources:
        expression = exp.expand(
            expression,
            {
                k: t.cast(exp.Query, maybe_parse(v, copy=copy, dialect=dialect))
                for k, v in sources.items()
            },
            dialect=dialect,
            copy=copy,
        )

    if not scope:
        expression = qualify.qualify(
            expression,
            dialect=dialect,
            schema=schema,
            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
        )

        scope = build_scope(expression)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    if not any(select.alias_or_name == column for select in scope.expression.selects):
        raise SqlglotError(f"Cannot find column '{column}' in query.")

    return to_node(column, scope, dialect, trim_selects=trim_selects)
Build the lineage graph for a column of a SQL query.
Arguments:
- column: The column to build the lineage for.
- sql: The SQL string or expression.
- schema: The schema of tables.
- sources: A mapping of queries which will be used to continue building lineage.
- dialect: The dialect of input SQL.
- scope: A pre-created scope to use instead.
- trim_selects: Whether to clean up selects by trimming to only relevant columns.
- copy: Whether to copy the Expression arguments.
- **kwargs: Additional keyword arguments forwarded to the `qualify` optimizer (`qualify.qualify`).
Returns:
A lineage node.
def
to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType], scope_name: Optional[str] = None, upstream: Optional[Node] = None, source_name: Optional[str] = None, reference_node_name: Optional[str] = None, trim_selects: bool = True) -> Node:
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
    trim_selects: bool = True,
) -> Node:
    """Create the lineage node for `column` in `scope` and recurse into what it depends on.

    Args:
        column: The projection to trace, by name or by position (positions are used for the
            branches of a set operation).
        scope: The scope whose query produces the column.
        dialect: Dialect used when rendering SQL for log messages.
        scope_name: Alias under which the outer query reads this scope; prefixes the node name.
        upstream: The node depending on this one; the new node is appended to its `downstream`.
        source_name: Source name inherited from an enclosing ``source:``-commented derived table.
        reference_node_name: Name of the derived table / CTE this column is read through.
        trim_selects: Whether to trim each node's source to only the traced projection.

    Returns:
        The node for `column`. For set operations this is `upstream`, or a new node standing
        for the whole operation when no upstream was given.

    Raises:
        ValueError: If a set operation has no projection matching `column`.
    """
    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (select for select in scope.expression.selects if select.alias_or_name == column),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    if isinstance(scope.expression, exp.Subquery):
        # Only the first subquery scope is followed.
        for source in scope.subquery_scopes:
            return to_node(
                column,
                scope=source,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )
    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, projection in enumerate(scope.expression.selects)
                    if projection.alias_or_name == column or projection.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # Each branch of the set operation contributes the projection at the same position.
        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )

        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #   => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning("Unknown subquery scope: %s", subquery.sql(dialect=dialect))
            continue

        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
            )

    # If the select is a star, add all scope sources as downstreams. A dedicated loop variable
    # keeps `source` bound to this node's own source for the UDTF check below.
    if isinstance(select, exp.Star):
        for star_source in scope.sources.values():
            if isinstance(star_source, Scope):
                star_source = star_source.expression
            node.downstream.append(
                Node(name=select.sql(comments=False), source=star_source, expression=star_source)
            )

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))
        derived_tables = [
            scope_source.expression.parent
            for scope_source in scope.sources.values()
            if isinstance(scope_source, Scope) and scope_source.is_derived_table
        ]
    else:
        derived_tables = scope.derived_tables

    source_names = _source_names_from_comments(derived_tables)

    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        column_source = scope.sources.get(table)

        if isinstance(column_source, Scope):
            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=column_source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=_reference_node_name(scope, column_source, table, source_names),
                trim_selects=trim_selects,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                downstream_columns.append(exp.column(c.this, table=pivot.parent.alias_or_name))

            for downstream_column in downstream_columns:
                downstream_table = downstream_column.table
                downstream_source = scope.sources.get(downstream_table)
                if isinstance(downstream_source, Scope):
                    to_node(
                        downstream_column.name,
                        scope=downstream_source,
                        scope_name=downstream_table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(downstream_table) or source_name,
                        # Computed per source, not carried over from an earlier loop iteration.
                        reference_node_name=_reference_node_name(
                            scope, downstream_source, downstream_table, source_names
                        ),
                        trim_selects=trim_selects,
                    )
                else:
                    leaf = downstream_source or exp.Placeholder()
                    node.downstream.append(
                        Node(
                            name=downstream_column.sql(comments=False),
                            source=leaf,
                            expression=leaf,
                        )
                    )
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            leaf = column_source or exp.Placeholder()
            node.downstream.append(Node(name=c.sql(comments=False), source=leaf, expression=leaf))

    return node


def _source_names_from_comments(derived_tables: t.Iterable[exp.Expression]) -> t.Dict[str, str]:
    """Map derived-table aliases to the name in a leading ``source: <name>`` comment."""
    names: t.Dict[str, str] = {}
    for derived_table in derived_tables:
        comments = derived_table.comments
        if not comments or not comments[0].startswith("source: "):
            continue
        parts = comments[0].split()
        # A bare "source: " comment carries no name; skip it instead of raising IndexError.
        if len(parts) > 1:
            names[derived_table.alias] = parts[1]
    return names


def _reference_node_name(
    scope: Scope, source: Scope, table: str, source_names: t.Mapping[str, str]
) -> t.Optional[str]:
    """Return the derived-table / CTE name that columns of `table` are read through, if any."""
    if source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
        return table
    if source.scope_type == ScopeType.CTE:
        selected_node, _ = scope.selected_sources.get(table, (None, None))
        return selected_node.name if selected_node else None
    return None
class
GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        nodes: t.Dict,
        edges: t.List,
        imports: bool = True,
        options: t.Optional[t.Dict] = None,
        element_id: t.Optional[str] = None,
    ):
        """
        Args:
            nodes: vis.js node definitions keyed by node id (only the values are rendered).
            edges: vis.js edge definitions.
            imports: Whether to include the vis.js <script>/<link> tags in the output.
            options: vis.js network options; top-level keys replace the defaults below.
            element_id: DOM id of the graph container. Defaults to a unique id, so several graphs
                rendered into one page (e.g. notebook outputs) do not draw into each other.
        """
        import uuid

        self.imports = imports
        self.element_id = element_id or f"sqlglot-lineage-{uuid.uuid4().hex}"

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to embed inside an inline <script> element."""
        # "<" only occurs inside JSON strings; escaping it prevents data such as "</script>"
        # from terminating the element early. "\u003c" parses back to "<".
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self):
        import html

        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        container_id = self._script_json(self.element_id)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="{html.escape(self.element_id)}"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById({container_id}),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        return self.__str__()
Node to HTML generator using vis.js.
GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
def __init__(
    self,
    nodes: t.Dict,
    edges: t.List,
    imports: bool = True,
    options: t.Optional[t.Dict] = None,
    element_id: t.Optional[str] = None,
):
    """
    Args:
        nodes: vis.js node definitions keyed by node id.
        edges: vis.js edge definitions.
        imports: Whether to include the vis.js <script>/<link> tags in the output.
        options: vis.js network options; top-level keys replace the defaults below.
        element_id: DOM id of the graph container. Defaults to a unique id, so several graphs
            rendered into one page (e.g. notebook outputs) do not draw into each other.
    """
    import uuid

    self.imports = imports
    self.element_id = element_id or f"sqlglot-lineage-{uuid.uuid4().hex}"

    self.options = {
        "height": "500px",
        "width": "100%",
        "layout": {
            "hierarchical": {
                "enabled": True,
                "nodeSpacing": 200,
                "sortMethod": "directed",
            },
        },
        "interaction": {
            "dragNodes": False,
            "selectable": False,
        },
        "physics": {
            "enabled": False,
        },
        "edges": {
            "arrows": "to",
        },
        "nodes": {
            "font": "20px monaco",
            "shape": "box",
            "widthConstraint": {
                "maximum": 300,
            },
        },
        **(options or {}),
    }

    self.nodes = nodes
    self.edges = edges