Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import logging
  5import typing as t
  6from dataclasses import dataclass, field
  7
  8from sqlglot import Schema, exp, maybe_parse
  9from sqlglot.errors import SqlglotError
 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
 11from sqlglot.optimizer.scope import ScopeType
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot.dialects.dialect import DialectType
 15
 16logger = logging.getLogger("sqlglot")
 17
 18
 19@dataclass(frozen=True)
 20class Node:
 21    name: str
 22    expression: exp.Expression
 23    source: exp.Expression
 24    downstream: t.List[Node] = field(default_factory=list)
 25    source_name: str = ""
 26    reference_node_name: str = ""
 27
 28    def walk(self) -> t.Iterator[Node]:
 29        yield self
 30
 31        for d in self.downstream:
 32            yield from d.walk()
 33
 34    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
 35        nodes = {}
 36        edges = []
 37
 38        for node in self.walk():
 39            if isinstance(node.expression, exp.Table):
 40                label = f"FROM {node.expression.this}"
 41                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
 42                group = 1
 43            else:
 44                label = node.expression.sql(pretty=True, dialect=dialect)
 45                source = node.source.transform(
 46                    lambda n: (
 47                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
 48                    ),
 49                    copy=False,
 50                ).sql(pretty=True, dialect=dialect)
 51                title = f"<pre>{source}</pre>"
 52                group = 0
 53
 54            node_id = id(node)
 55
 56            nodes[node_id] = {
 57                "id": node_id,
 58                "label": label,
 59                "title": title,
 60                "group": group,
 61            }
 62
 63            for d in node.downstream:
 64                edges.append({"from": node_id, "to": id(d)})
 65        return GraphHTML(nodes, edges, **opts)
 66
 67
 68def lineage(
 69    column: str | exp.Column,
 70    sql: str | exp.Expression,
 71    schema: t.Optional[t.Dict | Schema] = None,
 72    sources: t.Optional[t.Mapping[str, str | exp.Query]] = None,
 73    dialect: DialectType = None,
 74    scope: t.Optional[Scope] = None,
 75    trim_selects: bool = True,
 76    **kwargs,
 77) -> Node:
 78    """Build the lineage graph for a column of a SQL query.
 79
 80    Args:
 81        column: The column to build the lineage for.
 82        sql: The SQL string or expression.
 83        schema: The schema of tables.
 84        sources: A mapping of queries which will be used to continue building lineage.
 85        dialect: The dialect of input SQL.
 86        scope: A pre-created scope to use instead.
 87        trim_selects: Whether or not to clean up selects by trimming to only relevant columns.
 88        **kwargs: Qualification optimizer kwargs.
 89
 90    Returns:
 91        A lineage node.
 92    """
 93
 94    expression = maybe_parse(sql, dialect=dialect)
 95    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
 96
 97    if sources:
 98        expression = exp.expand(
 99            expression,
100            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
101            dialect=dialect,
102        )
103
104    if not scope:
105        expression = qualify.qualify(
106            expression,
107            dialect=dialect,
108            schema=schema,
109            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
110        )
111
112        scope = build_scope(expression)
113
114    if not scope:
115        raise SqlglotError("Cannot build lineage, sql must be SELECT")
116
117    if not any(select.alias_or_name == column for select in scope.expression.selects):
118        raise SqlglotError(f"Cannot find column '{column}' in query.")
119
120    return to_node(column, scope, dialect, trim_selects=trim_selects)
121
122
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
    trim_selects: bool = True,
) -> Node:
    """Build the lineage node for one column of `scope` and recurse into its sources.

    Args:
        column: Name of the column, or its positional index in the scope's selects.
            Indices are passed when recursing into the branches of a set operation.
        scope: The scope whose expression produces the column.
        dialect: Dialect used to render SQL in the warning log message; also passed
            on to every recursive call.
        scope_name: If set, the node is named "<scope_name>.<column>" instead of "<column>".
        upstream: Node to which the newly created node is appended as a downstream.
        source_name: Stored on the created node ("" when None). Recursive calls inherit it
            unless the referenced derived table has its own "source: ..." comment.
        reference_node_name: Stored on the created node ("" when None).
        trim_selects: If true and the scope's expression is a SELECT, the node's `source`
            is that SELECT with only the relevant projection.

    Returns:
        The node created for `column`. For a subquery scope this is whatever the recursive
        call on its first nested scope returns; for a set operation it is the node the
        branches were attached to (`upstream` if given, otherwise a new node).

    Raises:
        ValueError: If `column` cannot be found among a set operation's selects.
    """
    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (select for select in scope.expression.selects if select.alias_or_name == column),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    # A subquery scope delegates to its first nested scope: the loop returns on the
    # first iteration, and falls through only when there are no nested scopes.
    if isinstance(scope.expression, exp.Subquery):
        for source in scope.subquery_scopes:
            return to_node(
                column,
                scope=source,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )
    # Set operations get a single node named after the operation's class (upper-cased),
    # unless an upstream node already exists; each branch is attached to that node by
    # resolving the column positionally.
    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        # Position of the column in the set operation's selects; a star select matches any name.
        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(scope.expression.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # The return values are ignored: each call appends its node to `upstream`.
        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )

        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    # Index the nested subquery scopes by the identity of their expression so that the
    # subqueries found inside `select` can be matched to their scope.
    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    # Every named select of a subquery used inside the projection becomes a downstream of this node.
    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
            continue

        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
            )

    # if the select is a star add all scope sources as downstreams
    # NOTE(review): this loop rebinds `source`, so when the select is a star the UDTF check
    # further down inspects the last scope source rather than the node's own source.
    # Confirm whether that is intended.
    if select.is_star:
        for source in scope.sources.values():
            if isinstance(source, Scope):
                source = source.expression
            node.downstream.append(
                Node(name=select.sql(comments=False), source=source, expression=source)
            )

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))
        # Collect the parent expression of every derived-table scope among the sources.
        derived_tables = [
            source.expression.parent
            for source in scope.sources.values()
            if isinstance(source, Scope) and source.is_derived_table
        ]
    else:
        derived_tables = scope.derived_tables

    # Map derived-table alias -> source name, for derived tables whose first comment starts
    # with "source: ". The name is the second whitespace-separated token of that comment.
    # Presumably these comments are attached when `sources` are expanded -- verify against
    # `exp.expand`.
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    # Only the case of exactly one pivot that is not an UNPIVOT is handled.
    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        # Pivot output column name -> columns referenced by the aggregation that produces it.
        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            # Aggregation i owns every pivot_aggs_count-th pivot column, starting at index i.
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        source = scope.sources.get(table)

        if isinstance(source, Scope):
            # NOTE(review): this rebinds the `reference_node_name` parameter. Because
            # `source_columns` is a set, the pivot branch below may see either the caller's
            # value or one left over from an earlier iteration -- confirm this is intended.
            reference_node_name = None
            if source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
                reference_node_name = table
            elif source.scope_type == ScopeType.CTE:
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                reference_node_name = selected_node.name if selected_node else None

            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                downstream_columns.append(exp.column(c.this, table=pivot.parent.this))

            # Resolve each underlying column the same way as a regular source column:
            # recurse when its table is a scope, otherwise add a terminal node.
            for downstream_column in downstream_columns:
                table = downstream_column.table
                source = scope.sources.get(table)
                if isinstance(source, Scope):
                    to_node(
                        downstream_column.name,
                        scope=source,
                        scope_name=table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(table) or source_name,
                        reference_node_name=reference_node_name,
                        trim_selects=trim_selects,
                    )
                else:
                    # Unknown sources are represented by a placeholder expression.
                    source = source or exp.Placeholder()
                    node.downstream.append(
                        Node(
                            name=downstream_column.sql(comments=False),
                            source=source,
                            expression=source,
                        )
                    )
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            source = source or exp.Placeholder()
            node.downstream.append(
                Node(name=c.sql(comments=False), source=source, expression=source)
            )

    return node
346
347
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and the vis.js network options.

        Args:
            nodes: Mapping whose values are the vis.js node records.
            edges: List of vis.js edge records.
            imports: Whether the generated HTML includes the vis.js script and stylesheet tags.
            options: Overrides for the default network options. The merge is shallow:
                a top-level key given here replaces the whole default value for that key.
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_safe_json(value: t.Any) -> str:
        """Serialize `value` to JSON that can be embedded inside a <script> element.

        In JSON output "<" can only occur inside string literals, so replacing it with
        the equivalent \\u003c escape keeps the decoded data identical while making sure
        text such as "</script>" or "<!--" cannot end or alter the surrounding script.
        """
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self):
        """Return the HTML snippet that renders the graph.

        The snippet holds a container div, the vis.js imports (when enabled) and a script
        that builds the network. Each node's "title" is parsed as HTML in the browser.
        """
        nodes = self._script_safe_json(list(self.nodes.values()))
        edges = self._script_safe_json(self.edges)
        options = self._script_safe_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`, for rich display in notebooks."""
        return self.__str__()
logger = <Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """One step in a column lineage graph.

    The dataclass is frozen, so fields cannot be rebound, but the `downstream`
    list itself is mutable and is what links a node to its children.
    """

    # Display name of the node; used in the tooltip of table nodes.
    name: str
    # Expression this node stands for. Table expressions are rendered as "FROM ..." nodes.
    expression: exp.Expression
    # Expression in which `expression` is highlighted when the node is rendered.
    source: exp.Expression
    # Child nodes; traversed by `walk` and drawn as edges from this node to each child.
    downstream: t.List[Node] = field(default_factory=list)
    source_name: str = ""
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node followed by all nodes below it, depth first."""
        yield self

        for child in self.downstream:
            yield from child.walk()

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Convert the graph rooted at this node into a `GraphHTML` object.

        Args:
            dialect: Dialect used to generate the SQL shown in labels and tooltips.
            **opts: Forwarded unchanged to the `GraphHTML` constructor.

        Returns:
            A `GraphHTML` holding one entry per visited node and one edge per
            (node, downstream node) pair.
        """
        graph_nodes = {}
        graph_edges = []

        for current in self.walk():
            current_id = id(current)
            expression = current.expression

            if isinstance(expression, exp.Table):
                label = f"FROM {expression.this}"
                title = f"<pre>SELECT {current.name} FROM {expression.this}</pre>"
                group = 1
            else:
                label = expression.sql(pretty=True, dialect=dialect)

                def _emphasize(n, target=expression):
                    # Wrap only the node's own expression in bold tags.
                    if n is target:
                        return exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    return n

                # copy=False: the identity check above needs the original objects.
                tagged = current.source.transform(_emphasize, copy=False)
                source = tagged.sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            graph_nodes[current_id] = {
                "id": current_id,
                "label": label,
                "title": title,
                "group": group,
            }

            graph_edges.extend(
                {"from": current_id, "to": id(child)} for child in current.downstream
            )

        return GraphHTML(graph_nodes, graph_edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
name: str
downstream: List[Node]
source_name: str = ''
reference_node_name: str = ''
def walk(self) -> Iterator[Node]:
29    def walk(self) -> t.Iterator[Node]:
30        yield self
31
32        for d in self.downstream:
33            yield from d.walk()
def to_html( self, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **opts) -> GraphHTML:
35    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
36        nodes = {}
37        edges = []
38
39        for node in self.walk():
40            if isinstance(node.expression, exp.Table):
41                label = f"FROM {node.expression.this}"
42                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
43                group = 1
44            else:
45                label = node.expression.sql(pretty=True, dialect=dialect)
46                source = node.source.transform(
47                    lambda n: (
48                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
49                    ),
50                    copy=False,
51                ).sql(pretty=True, dialect=dialect)
52                title = f"<pre>{source}</pre>"
53                group = 0
54
55            node_id = id(node)
56
57            nodes[node_id] = {
58                "id": node_id,
59                "label": label,
60                "title": title,
61                "group": group,
62            }
63
64            for d in node.downstream:
65                edges.append({"from": node_id, "to": id(d)})
66        return GraphHTML(nodes, edges, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Mapping[str, str | sqlglot.expressions.Query]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, scope: Optional[sqlglot.optimizer.scope.Scope] = None, trim_selects: bool = True, **kwargs) -> Node:
 69def lineage(
 70    column: str | exp.Column,
 71    sql: str | exp.Expression,
 72    schema: t.Optional[t.Dict | Schema] = None,
 73    sources: t.Optional[t.Mapping[str, str | exp.Query]] = None,
 74    dialect: DialectType = None,
 75    scope: t.Optional[Scope] = None,
 76    trim_selects: bool = True,
 77    **kwargs,
 78) -> Node:
 79    """Build the lineage graph for a column of a SQL query.
 80
 81    Args:
 82        column: The column to build the lineage for.
 83        sql: The SQL string or expression.
 84        schema: The schema of tables.
 85        sources: A mapping of queries which will be used to continue building lineage.
 86        dialect: The dialect of input SQL.
 87        scope: A pre-created scope to use instead.
 88        trim_selects: Whether or not to clean up selects by trimming to only relevant columns.
 89        **kwargs: Qualification optimizer kwargs.
 90
 91    Returns:
 92        A lineage node.
 93    """
 94
 95    expression = maybe_parse(sql, dialect=dialect)
 96    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
 97
 98    if sources:
 99        expression = exp.expand(
100            expression,
101            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
102            dialect=dialect,
103        )
104
105    if not scope:
106        expression = qualify.qualify(
107            expression,
108            dialect=dialect,
109            schema=schema,
110            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
111        )
112
113        scope = build_scope(expression)
114
115    if not scope:
116        raise SqlglotError("Cannot build lineage, sql must be SELECT")
117
118    if not any(select.alias_or_name == column for select in scope.expression.selects):
119        raise SqlglotError(f"Cannot find column '{column}' in query.")
120
121    return to_node(column, scope, dialect, trim_selects=trim_selects)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • dialect: The dialect of input SQL.
  • scope: A pre-created scope to use instead.
  • trim_selects: Whether or not to clean up selects by trimming to only relevant columns.
  • **kwargs: Qualification optimizer kwargs.
Returns:

A lineage node.

def to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType], scope_name: Optional[str] = None, upstream: Optional[Node] = None, source_name: Optional[str] = None, reference_node_name: Optional[str] = None, trim_selects: bool = True) -> Node:
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
    trim_selects: bool = True,
) -> Node:
    """Build the lineage node for one column of `scope` and recurse into its sources.

    Args:
        column: Name of the column, or its positional index in the scope's selects.
            Indices are passed when recursing into the branches of a set operation.
        scope: The scope whose expression produces the column.
        dialect: Dialect used to render SQL in the warning log message; also passed
            on to every recursive call.
        scope_name: If set, the node is named "<scope_name>.<column>" instead of "<column>".
        upstream: Node to which the newly created node is appended as a downstream.
        source_name: Stored on the created node ("" when None). Recursive calls inherit it
            unless the referenced derived table has its own "source: ..." comment.
        reference_node_name: Stored on the created node ("" when None).
        trim_selects: If true and the scope's expression is a SELECT, the node's `source`
            is that SELECT with only the relevant projection.

    Returns:
        The node created for `column`. For a subquery scope this is whatever the recursive
        call on its first nested scope returns; for a set operation it is the node the
        branches were attached to (`upstream` if given, otherwise a new node).

    Raises:
        ValueError: If `column` cannot be found among a set operation's selects.
    """
    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (select for select in scope.expression.selects if select.alias_or_name == column),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    # A subquery scope delegates to its first nested scope: the loop returns on the
    # first iteration, and falls through only when there are no nested scopes.
    if isinstance(scope.expression, exp.Subquery):
        for source in scope.subquery_scopes:
            return to_node(
                column,
                scope=source,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )
    # Set operations get a single node named after the operation's class (upper-cased),
    # unless an upstream node already exists; each branch is attached to that node by
    # resolving the column positionally.
    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        # Position of the column in the set operation's selects; a star select matches any name.
        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(scope.expression.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # The return values are ignored: each call appends its node to `upstream`.
        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )

        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    # Index the nested subquery scopes by the identity of their expression so that the
    # subqueries found inside `select` can be matched to their scope.
    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    # Every named select of a subquery used inside the projection becomes a downstream of this node.
    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
            continue

        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
            )

    # if the select is a star add all scope sources as downstreams
    # NOTE(review): this loop rebinds `source`, so when the select is a star the UDTF check
    # further down inspects the last scope source rather than the node's own source.
    # Confirm whether that is intended.
    if select.is_star:
        for source in scope.sources.values():
            if isinstance(source, Scope):
                source = source.expression
            node.downstream.append(
                Node(name=select.sql(comments=False), source=source, expression=source)
            )

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))
        # Collect the parent expression of every derived-table scope among the sources.
        derived_tables = [
            source.expression.parent
            for source in scope.sources.values()
            if isinstance(source, Scope) and source.is_derived_table
        ]
    else:
        derived_tables = scope.derived_tables

    # Map derived-table alias -> source name, for derived tables whose first comment starts
    # with "source: ". The name is the second whitespace-separated token of that comment.
    # Presumably these comments are attached when `sources` are expanded -- verify against
    # `exp.expand`.
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    # Only the case of exactly one pivot that is not an UNPIVOT is handled.
    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        # Pivot output column name -> columns referenced by the aggregation that produces it.
        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            # Aggregation i owns every pivot_aggs_count-th pivot column, starting at index i.
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        source = scope.sources.get(table)

        if isinstance(source, Scope):
            # NOTE(review): this rebinds the `reference_node_name` parameter. Because
            # `source_columns` is a set, the pivot branch below may see either the caller's
            # value or one left over from an earlier iteration -- confirm this is intended.
            reference_node_name = None
            if source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
                reference_node_name = table
            elif source.scope_type == ScopeType.CTE:
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                reference_node_name = selected_node.name if selected_node else None

            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                downstream_columns.append(exp.column(c.this, table=pivot.parent.this))

            # Resolve each underlying column the same way as a regular source column:
            # recurse when its table is a scope, otherwise add a terminal node.
            for downstream_column in downstream_columns:
                table = downstream_column.table
                source = scope.sources.get(table)
                if isinstance(source, Scope):
                    to_node(
                        downstream_column.name,
                        scope=source,
                        scope_name=table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(table) or source_name,
                        reference_node_name=reference_node_name,
                        trim_selects=trim_selects,
                    )
                else:
                    # Unknown sources are represented by a placeholder expression.
                    source = source or exp.Placeholder()
                    node.downstream.append(
                        Node(
                            name=downstream_column.sql(comments=False),
                            source=source,
                            expression=source,
                        )
                    )
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            source = source or exp.Placeholder()
            node.downstream.append(
                Node(name=c.sql(comments=False), source=source, expression=source)
            )

    return node
class GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    Holds vis-network node/edge records plus network options and renders them
    as an HTML snippet via `str()`; `_repr_html_` returns the same markup.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """
        Args:
            nodes: Mapping of node records; only its values are emitted into the page.
            edges: List of edge records.
            imports: Whether the rendered HTML includes the unpkg <script>/<link> tags
                for vis-data and vis-network.
            options: Network options laid over the defaults below. The merge is shallow:
                a top-level key given here replaces the default value for that key entirely.
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that can be inlined inside a <script> element.

        `json.dumps` leaves "</script>" and "<!--" untouched, and either sequence inside
        a string value would terminate or re-interpret the surrounding script element.
        "<" can only occur inside JSON strings, where "\\/" and "\\u003c" are escapes
        that decode back to the original characters, so the data itself is unchanged.
        """
        return json.dumps(value).replace("</", "<\\/").replace("<!--", "\\u003c!--")

    def __str__(self):
        """Return the HTML snippet that draws the graph into the "sqlglot-lineage" div."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        # Each node's "title" arrives as an HTML string; the inline script parses it into
        # a DOM element before handing the data set to vis.Network.
        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same markup as `str(self)` (rich-display hook, e.g. for notebooks)."""
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
355    def __init__(
356        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
357    ):
358        self.imports = imports
359
360        self.options = {
361            "height": "500px",
362            "width": "100%",
363            "layout": {
364                "hierarchical": {
365                    "enabled": True,
366                    "nodeSpacing": 200,
367                    "sortMethod": "directed",
368                },
369            },
370            "interaction": {
371                "dragNodes": False,
372                "selectable": False,
373            },
374            "physics": {
375                "enabled": False,
376            },
377            "edges": {
378                "arrows": "to",
379            },
380            "nodes": {
381                "font": "20px monaco",
382                "shape": "box",
383                "widthConstraint": {
384                    "maximum": 300,
385                },
386            },
387            **(options or {}),
388        }
389
390        self.nodes = nodes
391        self.edges = edges
imports
options
nodes
edges