sqlglot.lineage
1from __future__ import annotations 2 3import json 4import logging 5import typing as t 6from dataclasses import dataclass, field 7 8from sqlglot import Schema, exp, maybe_parse 9from sqlglot.errors import SqlglotError 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify 11from sqlglot.optimizer.scope import ScopeType 12 13if t.TYPE_CHECKING: 14 from sqlglot.dialects.dialect import DialectType 15 16logger = logging.getLogger("sqlglot") 17 18 19@dataclass(frozen=True) 20class Node: 21 name: str 22 expression: exp.Expression 23 source: exp.Expression 24 downstream: t.List[Node] = field(default_factory=list) 25 source_name: str = "" 26 reference_node_name: str = "" 27 28 def walk(self) -> t.Iterator[Node]: 29 yield self 30 31 for d in self.downstream: 32 yield from d.walk() 33 34 def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML: 35 nodes = {} 36 edges = [] 37 38 for node in self.walk(): 39 if isinstance(node.expression, exp.Table): 40 label = f"FROM {node.expression.this}" 41 title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>" 42 group = 1 43 else: 44 label = node.expression.sql(pretty=True, dialect=dialect) 45 source = node.source.transform( 46 lambda n: ( 47 exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n 48 ), 49 copy=False, 50 ).sql(pretty=True, dialect=dialect) 51 title = f"<pre>{source}</pre>" 52 group = 0 53 54 node_id = id(node) 55 56 nodes[node_id] = { 57 "id": node_id, 58 "label": label, 59 "title": title, 60 "group": group, 61 } 62 63 for d in node.downstream: 64 edges.append({"from": node_id, "to": id(d)}) 65 return GraphHTML(nodes, edges, **opts) 66 67 68def lineage( 69 column: str | exp.Column, 70 sql: str | exp.Expression, 71 schema: t.Optional[t.Dict | Schema] = None, 72 sources: t.Optional[t.Mapping[str, str | exp.Query]] = None, 73 dialect: DialectType = None, 74 scope: t.Optional[Scope] = None, 75 trim_selects: bool = True, 76 **kwargs, 
77) -> Node: 78 """Build the lineage graph for a column of a SQL query. 79 80 Args: 81 column: The column to build the lineage for. 82 sql: The SQL string or expression. 83 schema: The schema of tables. 84 sources: A mapping of queries which will be used to continue building lineage. 85 dialect: The dialect of input SQL. 86 scope: A pre-created scope to use instead. 87 trim_selects: Whether or not to clean up selects by trimming to only relevant columns. 88 **kwargs: Qualification optimizer kwargs. 89 90 Returns: 91 A lineage node. 92 """ 93 94 expression = maybe_parse(sql, dialect=dialect) 95 column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name 96 97 if sources: 98 expression = exp.expand( 99 expression, 100 {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()}, 101 dialect=dialect, 102 ) 103 104 if not scope: 105 expression = qualify.qualify( 106 expression, 107 dialect=dialect, 108 schema=schema, 109 **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore 110 ) 111 112 scope = build_scope(expression) 113 114 if not scope: 115 raise SqlglotError("Cannot build lineage, sql must be SELECT") 116 117 if not any(select.alias_or_name == column for select in scope.expression.selects): 118 raise SqlglotError(f"Cannot find column '{column}' in query.") 119 120 return to_node(column, scope, dialect, trim_selects=trim_selects) 121 122 123def to_node( 124 column: str | int, 125 scope: Scope, 126 dialect: DialectType, 127 scope_name: t.Optional[str] = None, 128 upstream: t.Optional[Node] = None, 129 source_name: t.Optional[str] = None, 130 reference_node_name: t.Optional[str] = None, 131 trim_selects: bool = True, 132) -> Node: 133 # Find the specific select clause that is the source of the column we want. 134 # This can either be a specific, named select or a generic `*` clause. 
135 select = ( 136 scope.expression.selects[column] 137 if isinstance(column, int) 138 else next( 139 (select for select in scope.expression.selects if select.alias_or_name == column), 140 exp.Star() if scope.expression.is_star else scope.expression, 141 ) 142 ) 143 144 if isinstance(scope.expression, exp.Subquery): 145 for source in scope.subquery_scopes: 146 return to_node( 147 column, 148 scope=source, 149 dialect=dialect, 150 upstream=upstream, 151 source_name=source_name, 152 reference_node_name=reference_node_name, 153 trim_selects=trim_selects, 154 ) 155 if isinstance(scope.expression, exp.SetOperation): 156 name = type(scope.expression).__name__.upper() 157 upstream = upstream or Node(name=name, source=scope.expression, expression=select) 158 159 index = ( 160 column 161 if isinstance(column, int) 162 else next( 163 ( 164 i 165 for i, select in enumerate(scope.expression.selects) 166 if select.alias_or_name == column or select.is_star 167 ), 168 -1, # mypy will not allow a None here, but a negative index should never be returned 169 ) 170 ) 171 172 if index == -1: 173 raise ValueError(f"Could not find {column} in {scope.expression}") 174 175 for s in scope.union_scopes: 176 to_node( 177 index, 178 scope=s, 179 dialect=dialect, 180 upstream=upstream, 181 source_name=source_name, 182 reference_node_name=reference_node_name, 183 trim_selects=trim_selects, 184 ) 185 186 return upstream 187 188 if trim_selects and isinstance(scope.expression, exp.Select): 189 # For better ergonomics in our node labels, replace the full select with 190 # a version that has only the column we care about. 191 # "x", SELECT x, y FROM foo 192 # => "x", SELECT x FROM foo 193 source = t.cast(exp.Expression, scope.expression.select(select, append=False)) 194 else: 195 source = scope.expression 196 197 # Create the node for this step in the lineage chain, and attach it to the previous one. 
198 node = Node( 199 name=f"{scope_name}.{column}" if scope_name else str(column), 200 source=source, 201 expression=select, 202 source_name=source_name or "", 203 reference_node_name=reference_node_name or "", 204 ) 205 206 if upstream: 207 upstream.downstream.append(node) 208 209 subquery_scopes = { 210 id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes 211 } 212 213 for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES): 214 subquery_scope = subquery_scopes.get(id(subquery)) 215 if not subquery_scope: 216 logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}") 217 continue 218 219 for name in subquery.named_selects: 220 to_node( 221 name, 222 scope=subquery_scope, 223 dialect=dialect, 224 upstream=node, 225 trim_selects=trim_selects, 226 ) 227 228 # if the select is a star add all scope sources as downstreams 229 if select.is_star: 230 for source in scope.sources.values(): 231 if isinstance(source, Scope): 232 source = source.expression 233 node.downstream.append( 234 Node(name=select.sql(comments=False), source=source, expression=source) 235 ) 236 237 # Find all columns that went into creating this one to list their lineage nodes. 
238 source_columns = set(find_all_in_scope(select, exp.Column)) 239 240 # If the source is a UDTF find columns used in the UTDF to generate the table 241 if isinstance(source, exp.UDTF): 242 source_columns |= set(source.find_all(exp.Column)) 243 derived_tables = [ 244 source.expression.parent 245 for source in scope.sources.values() 246 if isinstance(source, Scope) and source.is_derived_table 247 ] 248 else: 249 derived_tables = scope.derived_tables 250 251 source_names = { 252 dt.alias: dt.comments[0].split()[1] 253 for dt in derived_tables 254 if dt.comments and dt.comments[0].startswith("source: ") 255 } 256 257 pivots = scope.pivots 258 pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None 259 if pivot: 260 # For each aggregation function, the pivot creates a new column for each field in category 261 # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a, 262 # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum' 263 # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs 264 # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest 265 # in the lineage, so lookup the pivot column name by index and map that with the columns used 266 # in the aggregation. 
267 # 268 # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b') 269 pivot_columns = pivot.args["columns"] 270 pivot_aggs_count = len(pivot.expressions) 271 272 pivot_column_mapping = {} 273 for i, agg in enumerate(pivot.expressions): 274 agg_cols = list(agg.find_all(exp.Column)) 275 for col_index in range(i, len(pivot_columns), pivot_aggs_count): 276 pivot_column_mapping[pivot_columns[col_index].name] = agg_cols 277 278 for c in source_columns: 279 table = c.table 280 source = scope.sources.get(table) 281 282 if isinstance(source, Scope): 283 reference_node_name = None 284 if source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names: 285 reference_node_name = table 286 elif source.scope_type == ScopeType.CTE: 287 selected_node, _ = scope.selected_sources.get(table, (None, None)) 288 reference_node_name = selected_node.name if selected_node else None 289 290 # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. 291 to_node( 292 c.name, 293 scope=source, 294 dialect=dialect, 295 scope_name=table, 296 upstream=node, 297 source_name=source_names.get(table) or source_name, 298 reference_node_name=reference_node_name, 299 trim_selects=trim_selects, 300 ) 301 elif pivot and pivot.alias_or_name == c.table: 302 downstream_columns = [] 303 304 column_name = c.name 305 if any(column_name == pivot_column.name for pivot_column in pivot_columns): 306 downstream_columns.extend(pivot_column_mapping[column_name]) 307 else: 308 # The column is not in the pivot, so it must be an implicit column of the 309 # pivoted source -- adapt column to be from the implicit pivoted source. 
310 downstream_columns.append(exp.column(c.this, table=pivot.parent.this)) 311 312 for downstream_column in downstream_columns: 313 table = downstream_column.table 314 source = scope.sources.get(table) 315 if isinstance(source, Scope): 316 to_node( 317 downstream_column.name, 318 scope=source, 319 scope_name=table, 320 dialect=dialect, 321 upstream=node, 322 source_name=source_names.get(table) or source_name, 323 reference_node_name=reference_node_name, 324 trim_selects=trim_selects, 325 ) 326 else: 327 source = source or exp.Placeholder() 328 node.downstream.append( 329 Node( 330 name=downstream_column.sql(comments=False), 331 source=source, 332 expression=source, 333 ) 334 ) 335 else: 336 # The source is not a scope and the column is not in any pivot - we've reached the end 337 # of the line. At this point, if a source is not found it means this column's lineage 338 # is unknown. This can happen if the definition of a source used in a query is not 339 # passed into the `sources` map. 340 source = source or exp.Placeholder() 341 node.downstream.append( 342 Node(name=c.sql(comments=False), source=source, expression=source) 343 ) 344 345 return node 346 347 348class GraphHTML: 349 """Node to HTML generator using vis.js. 
350 351 https://visjs.github.io/vis-network/docs/network/ 352 """ 353 354 def __init__( 355 self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None 356 ): 357 self.imports = imports 358 359 self.options = { 360 "height": "500px", 361 "width": "100%", 362 "layout": { 363 "hierarchical": { 364 "enabled": True, 365 "nodeSpacing": 200, 366 "sortMethod": "directed", 367 }, 368 }, 369 "interaction": { 370 "dragNodes": False, 371 "selectable": False, 372 }, 373 "physics": { 374 "enabled": False, 375 }, 376 "edges": { 377 "arrows": "to", 378 }, 379 "nodes": { 380 "font": "20px monaco", 381 "shape": "box", 382 "widthConstraint": { 383 "maximum": 300, 384 }, 385 }, 386 **(options or {}), 387 } 388 389 self.nodes = nodes 390 self.edges = edges 391 392 def __str__(self): 393 nodes = json.dumps(list(self.nodes.values())) 394 edges = json.dumps(self.edges) 395 options = json.dumps(self.options) 396 imports = ( 397 """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script> 398 <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script> 399 <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />""" 400 if self.imports 401 else "" 402 ) 403 404 return f"""<div> 405 <div id="sqlglot-lineage"></div> 406 {imports} 407 <script type="text/javascript"> 408 var nodes = new vis.DataSet({nodes}) 409 nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0]) 410 411 new vis.Network( 412 document.getElementById("sqlglot-lineage"), 413 {{ 414 nodes: nodes, 415 edges: new vis.DataSet({edges}) 416 }}, 417 {options}, 418 ) 419 </script> 420</div>""" 421 422 def _repr_html_(self) -> str: 423 return self.__str__()
logger =
<Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class
Node:
@dataclass(frozen=True)
class Node:
    """A single step in a column-lineage graph.

    Each node describes one projection (or source) in one query scope and links to the
    nodes it is derived from through `downstream`.
    """

    # Display name, e.g. "x" or "<scope_name>.x"; leaf nodes use the column's SQL.
    name: str
    # The projection being traced; for leaf nodes this is the same object as `source`.
    expression: exp.Expression
    # The query (possibly trimmed to the traced projection) or table `expression` comes from.
    source: exp.Expression
    # Nodes this one is derived from.
    downstream: t.List[Node] = field(default_factory=list)
    # Name parsed from a "source: <name>" comment on a derived table, or "".
    source_name: str = ""
    # Name of the derived table / CTE through which this node was reached, or "".
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every node below it in depth-first order."""
        yield self

        for d in self.downstream:
            yield from d.walk()

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Render this lineage graph as an interactive vis.js HTML snippet.

        Table nodes are labelled "FROM <table>". Every other node is labelled with its
        expression's SQL. Its tooltip shows the source query with that expression in
        bold. `node.source` trees are restored after rendering, so they are not left
        modified.

        Args:
            dialect: The SQL dialect used to render labels and tooltips.
            **opts: Keyword arguments forwarded to `GraphHTML` (`imports`, `options`).

        Returns:
            A `GraphHTML` whose string form is the HTML snippet.
        """
        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)
                target = node.expression
                # Building the tag re-parents `target`, and transform(copy=False) splices the
                # tag into node.source in place. Remember the original links so both edits
                # can be undone once the SQL is generated; otherwise the caller's tree would
                # keep the tag (and a second rendering would wrap it again).
                parent, arg_key, index = target.parent, target.arg_key, target.index
                highlighted = node.source.transform(
                    lambda n: (
                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is target else n
                    ),
                    copy=False,
                )
                try:
                    source = highlighted.sql(pretty=True, dialect=dialect)
                finally:
                    wrapper = target.parent
                    if isinstance(wrapper, exp.Tag) and wrapper is not parent:
                        if wrapper.parent is not None:
                            # The tag was spliced into the tree: put `target` back in its slot.
                            wrapper.replace(target)
                        else:
                            # `target` was the root of node.source: only its links changed.
                            target.parent, target.arg_key, target.index = parent, arg_key, index
                title = f"<pre>{source}</pre>"
                group = 0

            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in node.downstream:
                edges.append({"from": node_id, "to": id(d)})
        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
expression: sqlglot.expressions.Expression
source: sqlglot.expressions.Expression
downstream: List[Node]
def
to_html( self, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **opts) -> GraphHTML:
def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
    """Render this lineage graph as an interactive vis.js HTML snippet.

    Table nodes are labelled "FROM <table>". Every other node is labelled with its
    expression's SQL. Its tooltip shows the source query with that expression in
    bold. `node.source` trees are restored after rendering, so they are not left
    modified.

    Args:
        dialect: The SQL dialect used to render labels and tooltips.
        **opts: Keyword arguments forwarded to `GraphHTML` (`imports`, `options`).

    Returns:
        A `GraphHTML` whose string form is the HTML snippet.
    """
    nodes = {}
    edges = []

    for node in self.walk():
        if isinstance(node.expression, exp.Table):
            label = f"FROM {node.expression.this}"
            title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
            group = 1
        else:
            label = node.expression.sql(pretty=True, dialect=dialect)
            target = node.expression
            # Building the tag re-parents `target`, and transform(copy=False) splices the
            # tag into node.source in place. Remember the original links so both edits
            # can be undone once the SQL is generated; otherwise the caller's tree would
            # keep the tag (and a second rendering would wrap it again).
            parent, arg_key, index = target.parent, target.arg_key, target.index
            highlighted = node.source.transform(
                lambda n: (exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is target else n),
                copy=False,
            )
            try:
                source = highlighted.sql(pretty=True, dialect=dialect)
            finally:
                wrapper = target.parent
                if isinstance(wrapper, exp.Tag) and wrapper is not parent:
                    if wrapper.parent is not None:
                        # The tag was spliced into the tree: put `target` back in its slot.
                        wrapper.replace(target)
                    else:
                        # `target` was the root of node.source: only its links changed.
                        target.parent, target.arg_key, target.index = parent, arg_key, index
            title = f"<pre>{source}</pre>"
            group = 0

        node_id = id(node)

        nodes[node_id] = {
            "id": node_id,
            "label": label,
            "title": title,
            "group": group,
        }

        for d in node.downstream:
            edges.append({"from": node_id, "to": id(d)})
    return GraphHTML(nodes, edges, **opts)
def
lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Mapping[str, str | sqlglot.expressions.Query]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, scope: Optional[sqlglot.optimizer.scope.Scope] = None, trim_selects: bool = True, **kwargs) -> Node:
def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Mapping[str, str | exp.Query]] = None,
    dialect: DialectType = None,
    scope: t.Optional[Scope] = None,
    trim_selects: bool = True,
    **kwargs,
) -> Node:
    """Trace where one output column of a SQL query comes from.

    Args:
        column: The output column whose lineage is traced.
        sql: The query, as SQL text or an already parsed expression.
        schema: Table schemas, used while qualifying the query.
        sources: Named queries expanded into `sql` so lineage can continue through them.
        dialect: The SQL dialect of `sql` and of the `sources` queries.
        scope: An existing root scope; when given, qualification and scope building are
            skipped and this scope is used directly.
        trim_selects: Whether node sources keep only the projection being traced.
        **kwargs: Extra options passed through to `qualify`.

    Returns:
        The root lineage `Node` for `column`.

    Raises:
        SqlglotError: If no scope could be built or `column` is not projected by the query.
    """
    parsed = maybe_parse(sql, dialect=dialect)
    column_name = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name

    if sources:
        expansions = {
            source_key: t.cast(exp.Query, maybe_parse(source_query, dialect=dialect))
            for source_key, source_query in sources.items()
        }
        parsed = exp.expand(parsed, expansions, dialect=dialect)

    if not scope:
        # Caller-supplied kwargs take precedence over these lineage-friendly defaults.
        qualify_options = {"validate_qualify_columns": False, "identify": False}
        qualify_options.update(kwargs)
        parsed = qualify.qualify(
            parsed,
            dialect=dialect,
            schema=schema,
            **qualify_options,  # type: ignore
        )
        scope = build_scope(parsed)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    projected_names = {projection.alias_or_name for projection in scope.expression.selects}
    if column_name not in projected_names:
        raise SqlglotError(f"Cannot find column '{column_name}' in query.")

    return to_node(column_name, scope, dialect, trim_selects=trim_selects)
Build the lineage graph for a column of a SQL query.
Arguments:
- column: The column to build the lineage for.
- sql: The SQL string or expression.
- schema: The schema of tables.
- sources: A mapping of queries which will be used to continue building lineage.
- dialect: The dialect of input SQL.
- scope: A pre-built scope to use instead of qualifying `sql` and building a scope from it.
- trim_selects: Whether or not to clean up selects by trimming to only relevant columns.
- **kwargs: Additional keyword arguments forwarded to the `qualify` optimizer.
Returns:
A lineage node.
def
to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType], scope_name: Optional[str] = None, upstream: Optional[Node] = None, source_name: Optional[str] = None, reference_node_name: Optional[str] = None, trim_selects: bool = True) -> Node:
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
    trim_selects: bool = True,
) -> Node:
    """Create the lineage node for `column` in `scope` and recurse into its sources.

    Args:
        column: The projection to trace, by name or by position (positions are used when
            descending into the branches of a set operation).
        scope: The scope whose query produces `column`.
        dialect: Dialect used to render SQL in warnings; passed to recursive calls.
        scope_name: If given, the node is named "<scope_name>.<column>".
        upstream: If given, the new node is appended to `upstream.downstream`.
        source_name: Stored in the node's `source_name` field.
        reference_node_name: Stored in the node's `reference_node_name` field.
        trim_selects: Whether a SELECT source is trimmed to the traced projection.

    Returns:
        The created node. A parenthesized subquery scope returns the result of tracing
        the query it wraps. A set operation returns `upstream`, or a new node naming the
        operation.

    Raises:
        ValueError: If `column` cannot be found among a set operation's projections.
    """
    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (select for select in scope.expression.selects if select.alias_or_name == column),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    if isinstance(scope.expression, exp.Subquery) and scope.subquery_scopes:
        # A parenthesized query: keep tracing inside the query it wraps.
        return to_node(
            column,
            scope=scope.subquery_scopes[0],
            dialect=dialect,
            upstream=upstream,
            source_name=source_name,
            reference_node_name=reference_node_name,
            trim_selects=trim_selects,
        )

    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        # Every branch contributes the projection at the same position, so resolve the
        # column to an index and trace that index through each branch.
        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(scope.expression.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )

        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #   => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    # Trace every projection of the subqueries nested inside this select.
    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
            continue

        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
            )

    # If the select is a star, add all scope sources as downstreams.
    # NOTE(review): this loop rebinds `source`, so for star selects the UDTF check below
    # inspects the last scope source rather than this node's source -- confirm intended.
    if select.is_star:
        for source in scope.sources.values():
            if isinstance(source, Scope):
                source = source.expression
            node.downstream.append(
                Node(name=select.sql(comments=False), source=source, expression=source)
            )

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF, find the columns used in the UDTF to generate the table.
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))
        derived_tables = [
            source_scope.expression.parent
            for source_scope in scope.sources.values()
            if isinstance(source_scope, Scope) and source_scope.is_derived_table
        ]
    else:
        derived_tables = scope.derived_tables

    # Map derived-table aliases to the name in a leading "source: <name>" comment
    # (presumably attached when `sources` are expanded into the query).
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        source = scope.sources.get(table)

        if isinstance(source, Scope):
            # Use a separate local instead of rebinding the `reference_node_name` argument:
            # the pivot branch below forwards that argument, and because `source_columns`
            # is a set, a rebinding here would leak into it in iteration-order-dependent ways.
            child_reference_name = None
            if source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
                child_reference_name = table
            elif source.scope_type == ScopeType.CTE:
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                child_reference_name = selected_node.name if selected_node else None

            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=child_reference_name,
                trim_selects=trim_selects,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                downstream_columns.append(exp.column(c.this, table=pivot.parent.this))

            for downstream_column in downstream_columns:
                table = downstream_column.table
                source = scope.sources.get(table)
                if isinstance(source, Scope):
                    to_node(
                        downstream_column.name,
                        scope=source,
                        scope_name=table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(table) or source_name,
                        reference_node_name=reference_node_name,
                        trim_selects=trim_selects,
                    )
                else:
                    source = source or exp.Placeholder()
                    node.downstream.append(
                        Node(
                            name=downstream_column.sql(comments=False),
                            source=source,
                            expression=source,
                        )
                    )
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            source = source or exp.Placeholder()
            node.downstream.append(
                Node(name=c.sql(comments=False), source=source, expression=source)
            )

    return node
class
GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to embed inside a <script> element."""
        # A literal "</" (e.g. SQL containing '</script>') could terminate the surrounding
        # script element; "<\/" means the same thing in both JSON and JavaScript.
        return json.dumps(value).replace("</", "<\\/")

    def __str__(self):
        import uuid

        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        # A fresh container id per rendering, so several graphs displayed on one page
        # (e.g. in a notebook) each draw into their own element.
        container_id = f"sqlglot-lineage-{uuid.uuid4().hex}"
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="{container_id}" class="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("{container_id}"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        return self.__str__()
Node to HTML generator using vis.js.
GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
def __init__(
    self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
):
    """Store the graph data together with the vis.js network options.

    Args:
        nodes: Mapping of node id to a vis.js node dict (only the values are rendered).
        edges: List of vis.js edge dicts.
        imports: Whether the rendered HTML includes the vis.js script/stylesheet tags.
        options: vis.js options; their top-level keys replace the defaults below.
    """
    defaults = dict(
        height="500px",
        width="100%",
        layout={"hierarchical": {"enabled": True, "nodeSpacing": 200, "sortMethod": "directed"}},
        interaction={"dragNodes": False, "selectable": False},
        physics={"enabled": False},
        edges={"arrows": "to"},
        nodes={"font": "20px monaco", "shape": "box", "widthConstraint": {"maximum": 300}},
    )

    self.imports = imports
    self.options = {**defaults, **(options or {})}
    self.nodes = nodes
    self.edges = edges