Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import logging
  5import typing as t
  6from dataclasses import dataclass, field
  7
  8from sqlglot import Schema, exp, maybe_parse
  9from sqlglot.errors import SqlglotError
 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
 11from sqlglot.optimizer.scope import ScopeType
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot.dialects.dialect import DialectType
 15    from collections.abc import Iterator, Mapping, Sequence
 16    from sqlglot._typing import GraphHTMLArgs
 17    from typing_extensions import Unpack
 18
 19logger = logging.getLogger("sqlglot")
 20
 21
 22@dataclass(frozen=True)
 23class Node:
 24    name: str
 25    expression: exp.Expr
 26    source: exp.Expr
 27    downstream: list[Node] = field(default_factory=list)
 28    source_name: str = ""
 29    reference_node_name: str = ""
 30
 31    def walk(self) -> Iterator[Node]:
 32        visited: set[int] = set()
 33        queue = [self]
 34        while queue:
 35            node = queue.pop()
 36            node_id = id(node)
 37            if node_id in visited:
 38                continue
 39            visited.add(node_id)
 40            yield node
 41            queue.extend(reversed(node.downstream))
 42
 43    def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
 44        nodes = {}
 45        edges = []
 46
 47        for node in self.walk():
 48            if isinstance(node.expression, exp.Table):
 49                label = f"FROM {node.expression.this}"
 50                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
 51                group = 1
 52            else:
 53                label = node.expression.sql(pretty=True, dialect=dialect)
 54                source = node.source.transform(
 55                    lambda n: (
 56                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
 57                    ),
 58                    copy=False,
 59                ).sql(pretty=True, dialect=dialect)
 60                title = f"<pre>{source}</pre>"
 61                group = 0
 62
 63            node_id = id(node)
 64
 65            nodes[node_id] = {
 66                "id": node_id,
 67                "label": label,
 68                "title": title,
 69                "group": group,
 70            }
 71
 72            for d in node.downstream:
 73                edges.append({"from": node_id, "to": id(d)})
 74        return GraphHTML(nodes, edges, **opts)
 75
 76
 77def lineage(
 78    column: str | exp.Column,
 79    sql: str | exp.Expr,
 80    schema: dict | Schema | None = None,
 81    sources: Mapping[str, str | exp.Query] | None = None,
 82    dialect: DialectType = None,
 83    scope: Scope | None = None,
 84    trim_selects: bool = True,
 85    copy: bool = True,
 86    **kwargs,
 87) -> Node:
 88    """Build the lineage graph for a column of a SQL query.
 89
 90    Args:
 91        column: The column to build the lineage for.
 92        sql: The SQL string or expression.
 93        schema: The schema of tables.
 94        sources: A mapping of queries which will be used to continue building lineage.
 95        dialect: The dialect of input SQL.
 96        scope: A pre-created scope to use instead.
 97        trim_selects: Whether to clean up selects by trimming to only relevant columns.
 98        copy: Whether to copy the Expr arguments.
 99        **kwargs: Qualification optimizer kwargs.
100
101    Returns:
102        A lineage node.
103    """
104
105    expression = maybe_parse(sql, copy=copy, dialect=dialect)
106    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
107
108    if sources:
109        expression = exp.expand(
110            expression,
111            {
112                k: t.cast(exp.Query, maybe_parse(v, copy=copy, dialect=dialect))
113                for k, v in sources.items()
114            },
115            dialect=dialect,
116            copy=copy,
117        )
118
119    if not scope:
120        expression = qualify.qualify(
121            expression,
122            dialect=dialect,
123            schema=schema,
124            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
125        )
126
127        scope = build_scope(expression)
128
129    if not scope:
130        raise SqlglotError("Cannot build lineage, sql must be SELECT")
131
132    selectable = scope.expression
133    if not isinstance(selectable, exp.Selectable) or not any(
134        select.alias_or_name == column for select in selectable.selects
135    ):
136        raise SqlglotError(f"Cannot find column '{column}' in query.")
137
138    return to_node(column, scope, dialect, trim_selects=trim_selects, _cache={})
139
140
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: str | None = None,
    upstream: Node | None = None,
    source_name: str | None = None,
    reference_node_name: str | None = None,
    trim_selects: bool = True,
    _cache: dict[tuple, Node] | None = None,
) -> Node:
    """Build the lineage node for one select of `scope` and recurse into what feeds it.

    Args:
        column: Name (alias or name) of the select to trace, or its positional index.
        scope: The scope whose expression contains the select.
        dialect: Dialect used when rendering SQL in error and warning messages;
            also forwarded to recursive calls.
        scope_name: When set, the created node is named `{scope_name}.{column}`.
        upstream: Node to which the resulting node is appended as a downstream.
        source_name: Stored on the created node. Child nodes use the name taken
            from a `source: ` comment on their derived table when one exists,
            otherwise this value.
        reference_node_name: Stored on the created node.
        trim_selects: Whether a SELECT source is reduced to only the traced select.
        _cache: Memo of nodes keyed by (column, id(scope), scope_name, source_name,
            reference_node_name). `None` disables memoization.

    Returns:
        The node for the requested column. For set operations this is `upstream`
        when one was given, otherwise a new node named after the operation type.

    Raises:
        SqlglotError: If `column` is an index beyond the available selects.
        ValueError: If `column` cannot be located in a set operation's selects.
    """
    cache_key = (column, id(scope), scope_name, source_name, reference_node_name)

    if _cache is not None and cache_key in _cache:
        cached_node = _cache[cache_key]
        if upstream:
            upstream.downstream.append(cached_node)
        return cached_node

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    selectable = t.cast(exp.Selectable, scope.expression)
    if isinstance(column, int):
        if column >= len(selectable.selects):
            raise SqlglotError(
                f"Cannot find column's source with index {column} in query: {selectable.sql(dialect=dialect)}"
            )
        select = selectable.selects[column]
    else:
        select = next(
            (select for select in selectable.selects if select.alias_or_name == column),
            exp.Star() if selectable.is_star else scope.expression,
        )

    if isinstance(scope.expression, exp.Subquery):
        # Only the first subquery scope is followed; with none, fall through.
        for inner_scope in scope.subquery_scopes:
            result = to_node(
                column,
                scope=inner_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
            )
            if _cache is not None:
                _cache[cache_key] = result
            return result
    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(selectable.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # Each branch of the set operation contributes the select at the same position.
        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
            )

        if _cache is not None:
            _cache[cache_key] = upstream
        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source: exp.Expr = scope.expression.select(select, append=False)
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, *exp.UNWRAPPED_QUERIES):
        subquery_scope: Scope | None = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
            continue

        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
                _cache=_cache,
            )

    # if the select is a star add all scope sources as downstreams
    if isinstance(select, exp.Star):
        for src in scope.sources.values():
            src_expr = src.expression if isinstance(src, Scope) else src
            node.downstream.append(
                Node(name=select.sql(comments=False), source=src_expr, expression=src_expr)
            )

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))
        derived_tables: Sequence[exp.Expr] = [
            src.expression.parent
            for src in scope.sources.values()
            if isinstance(src, Scope) and src.is_derived_table and src.expression.parent
        ]
    else:
        derived_tables = scope.derived_tables

    # Map derived-table alias -> name given in a leading "source: <name>" comment.
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        col_source: exp.Table | Scope | None = scope.sources.get(table)

        if isinstance(col_source, Scope):
            # Kept in a loop-local name on purpose: rebinding the `reference_node_name`
            # parameter here would leak into the pivot branch of later iterations, and
            # `source_columns` is a set, so which iterations come "later" is not fixed.
            child_reference_name = None
            if col_source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
                child_reference_name = table
            elif col_source.scope_type == ScopeType.CTE:
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                child_reference_name = selected_node.name if selected_node else None

            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=col_source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=child_reference_name,
                trim_selects=trim_selects,
                _cache=_cache,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                pivot_parent = pivot.parent
                downstream_columns.append(
                    exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "")
                )

            for downstream_column in downstream_columns:
                table = downstream_column.table
                col_source = scope.sources.get(table)
                if isinstance(col_source, Scope):
                    # Forwards the value this call was invoked with.
                    # NOTE(review): presumably the intended behavior -- confirm.
                    to_node(
                        downstream_column.name,
                        scope=col_source,
                        scope_name=table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(table) or source_name,
                        reference_node_name=reference_node_name,
                        trim_selects=trim_selects,
                        _cache=_cache,
                    )
                else:
                    col_expr = col_source or exp.Placeholder()
                    node.downstream.append(
                        Node(
                            name=downstream_column.sql(comments=False),
                            source=col_expr,
                            expression=col_expr,
                        )
                    )
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            col_expr = col_source or exp.Placeholder()
            node.downstream.append(
                Node(name=c.sql(comments=False), source=col_expr, expression=col_expr)
            )

    if _cache is not None:
        _cache[cache_key] = node

    return node
392
393
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        nodes: dict,
        edges: list,
        imports: bool = True,
        options: Mapping[str, object] | None = None,
    ):
        """Store the graph data and the vis.js options used to render it.

        Args:
            nodes: Mapping whose values are the node dicts passed to vis.js.
            edges: List of edge dicts passed to vis.js.
            imports: Whether the rendered HTML includes the vis.js script and stylesheet tags.
            options: Top-level vis.js options. Each key given here replaces the
                default of the same name as a whole (the merge is not recursive).
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: object) -> str:
        """Serialize `value` to JSON that is safe to embed inside an inline <script>.

        In `json.dumps` output a "<" can only occur inside a string literal, so
        rewriting it as the escape \\u003c keeps the decoded value identical while
        preventing sequences such as "</script>" or "<!--" from ending or
        corrupting the surrounding script element.
        """
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self) -> str:
        """Render the graph as an HTML snippet that draws it with vis.js."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `str(self)`."""
        return self.__str__()
logger = <Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """A single step in a column lineage graph.

    `downstream` holds the nodes this node points to; `walk` follows those
    links and `to_html` renders the reachable graph.
    """

    name: str
    expression: exp.Expr
    source: exp.Expr
    downstream: list[Node] = field(default_factory=list)
    source_name: str = ""
    reference_node_name: str = ""

    def walk(self) -> Iterator[Node]:
        """Yield this node and every node reachable through `downstream`, each exactly once.

        Traversal is depth-first and pre-order, visiting downstream nodes in list
        order. Nodes are de-duplicated by object identity, so shared nodes and
        cycles are only yielded the first time they are reached.
        """
        seen: set[int] = set()
        stack = [self]
        while stack:
            current = stack.pop()
            key = id(current)
            if key not in seen:
                seen.add(key)
                yield current
                # Push in reverse so that the first downstream node is popped first.
                stack.extend(reversed(current.downstream))

    def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
        """Build a `GraphHTML` with one entry per walked node and one edge per downstream link.

        Args:
            dialect: Dialect used when rendering node expressions to SQL.
            **opts: Forwarded to the `GraphHTML` constructor.

        Returns:
            The `GraphHTML` wrapper around the collected nodes and edges.
        """
        nodes = {}
        edges = []

        for node in self.walk():
            node_id = id(node)
            expression = node.expression

            if isinstance(expression, exp.Table):
                group = 1
                label = f"FROM {expression.this}"
                title = f"<pre>SELECT {node.name} FROM {expression.this}</pre>"
            else:
                group = 0
                label = expression.sql(pretty=True, dialect=dialect)

                def highlight(n):
                    # Wrap only the node's own expression (matched by identity) in bold tags.
                    if n is expression:
                        return exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    return n

                # copy=False keeps object identities intact so the identity match can succeed.
                tagged_source = node.source.transform(highlight, copy=False)
                rendered_source = tagged_source.sql(pretty=True, dialect=dialect)
                title = f"<pre>{rendered_source}</pre>"

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }
            edges.extend({"from": node_id, "to": id(child)} for child in node.downstream)

        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.core.Expr, source: sqlglot.expressions.core.Expr, downstream: list[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
name: str
downstream: list[Node]
source_name: str = ''
reference_node_name: str = ''
def walk(self) -> Iterator[Node]:
32    def walk(self) -> Iterator[Node]:
33        visited: set[int] = set()
34        queue = [self]
35        while queue:
36            node = queue.pop()
37            node_id = id(node)
38            if node_id in visited:
39                continue
40            visited.add(node_id)
41            yield node
42            queue.extend(reversed(node.downstream))
def to_html( self, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, **opts: typing_extensions.Unpack[sqlglot._typing.GraphHTMLArgs]) -> GraphHTML:
44    def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
45        nodes = {}
46        edges = []
47
48        for node in self.walk():
49            if isinstance(node.expression, exp.Table):
50                label = f"FROM {node.expression.this}"
51                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
52                group = 1
53            else:
54                label = node.expression.sql(pretty=True, dialect=dialect)
55                source = node.source.transform(
56                    lambda n: (
57                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
58                    ),
59                    copy=False,
60                ).sql(pretty=True, dialect=dialect)
61                title = f"<pre>{source}</pre>"
62                group = 0
63
64            node_id = id(node)
65
66            nodes[node_id] = {
67                "id": node_id,
68                "label": label,
69                "title": title,
70                "group": group,
71            }
72
73            for d in node.downstream:
74                edges.append({"from": node_id, "to": id(d)})
75        return GraphHTML(nodes, edges, **opts)
def lineage( column: str | sqlglot.expressions.core.Column, sql: str | sqlglot.expressions.core.Expr, schema: dict | sqlglot.schema.Schema | None = None, sources: Mapping[str, str | sqlglot.expressions.query.Query] | None = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, scope: sqlglot.optimizer.scope.Scope | None = None, trim_selects: bool = True, copy: bool = True, **kwargs) -> Node:
 78def lineage(
 79    column: str | exp.Column,
 80    sql: str | exp.Expr,
 81    schema: dict | Schema | None = None,
 82    sources: Mapping[str, str | exp.Query] | None = None,
 83    dialect: DialectType = None,
 84    scope: Scope | None = None,
 85    trim_selects: bool = True,
 86    copy: bool = True,
 87    **kwargs,
 88) -> Node:
 89    """Build the lineage graph for a column of a SQL query.
 90
 91    Args:
 92        column: The column to build the lineage for.
 93        sql: The SQL string or expression.
 94        schema: The schema of tables.
 95        sources: A mapping of queries which will be used to continue building lineage.
 96        dialect: The dialect of input SQL.
 97        scope: A pre-created scope to use instead.
 98        trim_selects: Whether to clean up selects by trimming to only relevant columns.
 99        copy: Whether to copy the Expr arguments.
100        **kwargs: Qualification optimizer kwargs.
101
102    Returns:
103        A lineage node.
104    """
105
106    expression = maybe_parse(sql, copy=copy, dialect=dialect)
107    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
108
109    if sources:
110        expression = exp.expand(
111            expression,
112            {
113                k: t.cast(exp.Query, maybe_parse(v, copy=copy, dialect=dialect))
114                for k, v in sources.items()
115            },
116            dialect=dialect,
117            copy=copy,
118        )
119
120    if not scope:
121        expression = qualify.qualify(
122            expression,
123            dialect=dialect,
124            schema=schema,
125            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
126        )
127
128        scope = build_scope(expression)
129
130    if not scope:
131        raise SqlglotError("Cannot build lineage, sql must be SELECT")
132
133    selectable = scope.expression
134    if not isinstance(selectable, exp.Selectable) or not any(
135        select.alias_or_name == column for select in selectable.selects
136    ):
137        raise SqlglotError(f"Cannot find column '{column}' in query.")
138
139    return to_node(column, scope, dialect, trim_selects=trim_selects, _cache={})

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • dialect: The dialect of input SQL.
  • scope: A pre-created scope to use instead.
  • trim_selects: Whether to clean up selects by trimming to only relevant columns.
  • copy: Whether to copy the Expr arguments.
  • **kwargs: Qualification optimizer kwargs.
Returns:

A lineage node.

def to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType], scope_name: str | None = None, upstream: Node | None = None, source_name: str | None = None, reference_node_name: str | None = None, trim_selects: bool = True, _cache: dict[tuple, Node] | None = None) -> Node:
142def to_node(
143    column: str | int,
144    scope: Scope,
145    dialect: DialectType,
146    scope_name: str | None = None,
147    upstream: Node | None = None,
148    source_name: str | None = None,
149    reference_node_name: str | None = None,
150    trim_selects: bool = True,
151    _cache: dict[tuple, Node] | None = None,
152) -> Node:
153    cache_key = (column, id(scope), scope_name, source_name, reference_node_name)
154
155    if _cache is not None and cache_key in _cache:
156        cached_node = _cache[cache_key]
157        if upstream:
158            upstream.downstream.append(cached_node)
159        return cached_node
160
161    # Find the specific select clause that is the source of the column we want.
162    # This can either be a specific, named select or a generic `*` clause.
163    selectable = t.cast(exp.Selectable, scope.expression)
164    if isinstance(column, int):
165        if column >= len(selectable.selects):
166            raise SqlglotError(
167                f"Cannot find column's source with index {column} in query: {selectable.sql(dialect=dialect)}"
168            )
169        select = selectable.selects[column]
170    else:
171        select = next(
172            (select for select in selectable.selects if select.alias_or_name == column),
173            exp.Star() if selectable.is_star else scope.expression,
174        )
175
176    if isinstance(scope.expression, exp.Subquery):
177        for inner_scope in scope.subquery_scopes:
178            result = to_node(
179                column,
180                scope=inner_scope,
181                dialect=dialect,
182                upstream=upstream,
183                source_name=source_name,
184                reference_node_name=reference_node_name,
185                trim_selects=trim_selects,
186                _cache=_cache,
187            )
188            if _cache is not None:
189                _cache[cache_key] = result
190            return result
191    if isinstance(scope.expression, exp.SetOperation):
192        name = type(scope.expression).__name__.upper()
193        upstream = upstream or Node(name=name, source=scope.expression, expression=select)
194
195        index = (
196            column
197            if isinstance(column, int)
198            else next(
199                (
200                    i
201                    for i, select in enumerate(selectable.selects)
202                    if select.alias_or_name == column or select.is_star
203                ),
204                -1,  # mypy will not allow a None here, but a negative index should never be returned
205            )
206        )
207
208        if index == -1:
209            raise ValueError(f"Could not find {column} in {scope.expression}")
210
211        for s in scope.union_scopes:
212            to_node(
213                index,
214                scope=s,
215                dialect=dialect,
216                upstream=upstream,
217                source_name=source_name,
218                reference_node_name=reference_node_name,
219                trim_selects=trim_selects,
220                _cache=_cache,
221            )
222
223        if _cache is not None:
224            _cache[cache_key] = upstream
225        return upstream
226
227    if trim_selects and isinstance(scope.expression, exp.Select):
228        # For better ergonomics in our node labels, replace the full select with
229        # a version that has only the column we care about.
230        #   "x", SELECT x, y FROM foo
231        #     => "x", SELECT x FROM foo
232        source: exp.Expr = scope.expression.select(select, append=False)
233    else:
234        source = scope.expression
235
236    # Create the node for this step in the lineage chain, and attach it to the previous one.
237    node = Node(
238        name=f"{scope_name}.{column}" if scope_name else str(column),
239        source=source,
240        expression=select,
241        source_name=source_name or "",
242        reference_node_name=reference_node_name or "",
243    )
244
245    if upstream:
246        upstream.downstream.append(node)
247
248    subquery_scopes = {
249        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
250    }
251
252    for subquery in find_all_in_scope(select, *exp.UNWRAPPED_QUERIES):
253        subquery_scope: Scope | None = subquery_scopes.get(id(subquery))
254        if not subquery_scope:
255            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
256            continue
257
258        for name in subquery.named_selects:
259            to_node(
260                name,
261                scope=subquery_scope,
262                dialect=dialect,
263                upstream=node,
264                trim_selects=trim_selects,
265                _cache=_cache,
266            )
267
268    # if the select is a star add all scope sources as downstreams
269    if isinstance(select, exp.Star):
270        for src in scope.sources.values():
271            src_expr = src.expression if isinstance(src, Scope) else src
272            node.downstream.append(
273                Node(name=select.sql(comments=False), source=src_expr, expression=src_expr)
274            )
275
276    # Find all columns that went into creating this one to list their lineage nodes.
277    source_columns = set(find_all_in_scope(select, exp.Column))
278
279    # If the source is a UDTF find columns used in the UDTF to generate the table
280    if isinstance(source, exp.UDTF):
281        source_columns |= set(source.find_all(exp.Column))
282        derived_tables: Sequence[exp.Expr] = [
283            src.expression.parent
284            for src in scope.sources.values()
285            if isinstance(src, Scope) and src.is_derived_table and src.expression.parent
286        ]
287    else:
288        derived_tables = scope.derived_tables
289
290    source_names = {
291        dt.alias: dt.comments[0].split()[1]
292        for dt in derived_tables
293        if dt.comments and dt.comments[0].startswith("source: ")
294    }
295
296    pivots = scope.pivots
297    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
298    if pivot:
299        # For each aggregation function, the pivot creates a new column for each field in category
300        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
301        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
302        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
303        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
304        # in the lineage, so lookup the pivot column name by index and map that with the columns used
305        # in the aggregation.
306        #
307        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
308        pivot_columns = pivot.args["columns"]
309        pivot_aggs_count = len(pivot.expressions)
310
311        pivot_column_mapping = {}
312        for i, agg in enumerate(pivot.expressions):
313            agg_cols = list(agg.find_all(exp.Column))
314            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
315                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols
316
317    for c in source_columns:
318        table = c.table
319        col_source: exp.Table | Scope | None = scope.sources.get(table)
320
321        if isinstance(col_source, Scope):
322            reference_node_name = None
323            if col_source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
324                reference_node_name = table
325            elif col_source.scope_type == ScopeType.CTE:
326                selected_node, _ = scope.selected_sources.get(table, (None, None))
327                reference_node_name = selected_node.name if selected_node else None
328
329            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
330            to_node(
331                c.name,
332                scope=col_source,
333                dialect=dialect,
334                scope_name=table,
335                upstream=node,
336                source_name=source_names.get(table) or source_name,
337                reference_node_name=reference_node_name,
338                trim_selects=trim_selects,
339                _cache=_cache,
340            )
341        elif pivot and pivot.alias_or_name == c.table:
342            downstream_columns = []
343
344            column_name = c.name
345            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
346                downstream_columns.extend(pivot_column_mapping[column_name])
347            else:
348                # The column is not in the pivot, so it must be an implicit column of the
349                # pivoted source -- adapt column to be from the implicit pivoted source.
350                pivot_parent = pivot.parent
351                downstream_columns.append(
352                    exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "")
353                )
354
355            for downstream_column in downstream_columns:
356                table = downstream_column.table
357                col_source = scope.sources.get(table)
358                if isinstance(col_source, Scope):
359                    to_node(
360                        downstream_column.name,
361                        scope=col_source,
362                        scope_name=table,
363                        dialect=dialect,
364                        upstream=node,
365                        source_name=source_names.get(table) or source_name,
366                        reference_node_name=reference_node_name,
367                        trim_selects=trim_selects,
368                        _cache=_cache,
369                    )
370                else:
371                    col_expr = col_source or exp.Placeholder()
372                    node.downstream.append(
373                        Node(
374                            name=downstream_column.sql(comments=False),
375                            source=col_expr,
376                            expression=col_expr,
377                        )
378                    )
379        else:
380            # The source is not a scope and the column is not in any pivot - we've reached the end
381            # of the line. At this point, if a source is not found it means this column's lineage
382            # is unknown. This can happen if the definition of a source used in a query is not
383            # passed into the `sources` map.
384            col_expr = col_source or exp.Placeholder()
385            node.downstream.append(
386                Node(name=c.sql(comments=False), source=col_expr, expression=col_expr)
387            )
388
389    if _cache is not None:
390        _cache[cache_key] = node
391
392    return node
class GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/

    Args:
        nodes: Dictionary whose values are the vis.js node definitions; only the
            values are rendered.
        edges: List of vis.js edge definitions.
        imports: Whether to emit the `<script>`/`<link>` tags that load vis.js
            from unpkg.
        options: Top-level vis.js network options. Each key replaces the default
            of the same name (the merge is shallow, not recursive).
    """

    def __init__(
        self,
        nodes: dict,
        edges: list,
        imports: bool = True,
        options: Mapping[str, object] | None = None,
    ):
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            # User supplied options win over the defaults above, key by key.
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: object) -> str:
        """Serialize `value` to JSON that is safe to inline in a `<script>` element.

        A literal `</script>` (or `<!--`) inside a JSON string would otherwise be
        interpreted by the HTML parser and end or corrupt the script block. In JSON
        output `<` can only occur inside string values, so rewriting it as the
        `\\u003c` escape keeps the decoded data identical.
        """
        return json.dumps(value).replace("<", "\\u003c")

    def __str__(self) -> str:
        """Render the graph as an HTML snippet that draws it with vis.js."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the HTML rendering (same output as `__str__`)."""
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

GraphHTML( nodes: dict, edges: list, imports: bool = True, options: Mapping[str, object] | None = None)
401    def __init__(
402        self,
403        nodes: dict,
404        edges: list,
405        imports: bool = True,
406        options: Mapping[str, object] | None = None,
407    ):
408        self.imports = imports
409
410        self.options = {
411            "height": "500px",
412            "width": "100%",
413            "layout": {
414                "hierarchical": {
415                    "enabled": True,
416                    "nodeSpacing": 200,
417                    "sortMethod": "directed",
418                },
419            },
420            "interaction": {
421                "dragNodes": False,
422                "selectable": False,
423            },
424            "physics": {
425                "enabled": False,
426            },
427            "edges": {
428                "arrows": "to",
429            },
430            "nodes": {
431                "font": "20px monaco",
432                "shape": "box",
433                "widthConstraint": {
434                    "maximum": 300,
435                },
436            },
437            **(options or {}),
438        }
439
440        self.nodes = nodes
441        self.edges = edges
imports
options
nodes
edges