Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import logging
  5import typing as t
  6from dataclasses import dataclass, field
  7
  8from sqlglot import Schema, exp, maybe_parse
  9from sqlglot.errors import SqlglotError
 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
 11from sqlglot.optimizer.scope import ScopeType
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot.dialects.dialect import DialectType
 15    from collections.abc import Iterator, Mapping, Sequence
 16    from sqlglot._typing import GraphHTMLArgs
 17    from typing_extensions import Unpack
 18
 19logger = logging.getLogger("sqlglot")
 20
 21
 22@dataclass(frozen=True)
 23class Node:
 24    name: str
 25    expression: exp.Expr
 26    source: exp.Expr
 27    downstream: list[Node] = field(default_factory=list)
 28    source_name: str = ""
 29    reference_node_name: str = ""
 30
 31    # Caller-injected per-node data, populated via the `on_node` hook on lineage()
 32    payload: dict[str, t.Any] = field(default_factory=dict)
 33
 34    def walk(self) -> Iterator[Node]:
 35        visited: set[int] = set()
 36        queue = [self]
 37        while queue:
 38            node = queue.pop()
 39            node_id = id(node)
 40            if node_id in visited:
 41                continue
 42            visited.add(node_id)
 43            yield node
 44            queue.extend(reversed(node.downstream))
 45
 46    def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
 47        nodes = {}
 48        edges = []
 49
 50        for node in self.walk():
 51            if isinstance(node.expression, exp.Table):
 52                label = f"FROM {node.expression.this}"
 53                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
 54                group = 1
 55            else:
 56                label = node.expression.sql(pretty=True, dialect=dialect)
 57                source = node.source.transform(
 58                    lambda n: (
 59                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
 60                    ),
 61                    copy=False,
 62                ).sql(pretty=True, dialect=dialect)
 63                title = f"<pre>{source}</pre>"
 64                group = 0
 65
 66            node_id = id(node)
 67
 68            nodes[node_id] = {
 69                "id": node_id,
 70                "label": label,
 71                "title": title,
 72                "group": group,
 73            }
 74
 75            for d in node.downstream:
 76                edges.append({"from": node_id, "to": id(d)})
 77        return GraphHTML(nodes, edges, **opts)
 78
 79
 80@t.overload
 81def lineage(column: str | exp.Column, sql: str | exp.Expr, **kwargs: t.Any) -> Node: ...
 82
 83
 84@t.overload
 85def lineage(column: None, sql: str | exp.Expr, **kwargs: t.Any) -> dict[str, Node]: ...
 86
 87
 88def lineage(
 89    column: str | exp.Column | None,
 90    sql: str | exp.Expr,
 91    schema: dict | Schema | None = None,
 92    sources: Mapping[str, str | exp.Query] | None = None,
 93    dialect: DialectType = None,
 94    scope: Scope | None = None,
 95    trim_selects: bool = True,
 96    copy: bool = True,
 97    on_node: t.Callable[[Node], None] | None = None,
 98    **kwargs,
 99) -> Node | dict[str, Node]:
100    """Build the lineage graph for a SQL query.
101
102    If `column` is given, returns the lineage Node for that single output column.
103    If `column` is None, returns a dict mapping every top-level output column name
104    to its lineage Node (with a shared cache so cross-column work is deduplicated).
105
106    Args:
107        column: The column to build the lineage for. Pass None to get all output columns.
108        sql: The SQL string or expression.
109        schema: The schema of tables.
110        sources: A mapping of queries which will be used to continue building lineage.
111        dialect: The dialect of input SQL.
112        scope: A pre-created scope to use instead.
113        trim_selects: Whether to clean up selects by trimming to only relevant columns.
114        copy: Whether to copy the Expr arguments.
115        on_node: Optional callback invoked for every Node created during the walk,
116            after the Node's downstream is populated. Useful for injecting
117            caller-managed data into Node.payload during the walk.
118        **kwargs: Qualification optimizer kwargs.
119
120    Returns:
121        A Node when `column` is provided, or a dict[str, Node] when `column` is None.
122    """
123    expression = maybe_parse(sql, copy=copy, dialect=dialect)
124
125    if sources:
126        expression = exp.expand(
127            expression,
128            {
129                k: t.cast(exp.Query, maybe_parse(v, copy=copy, dialect=dialect))
130                for k, v in sources.items()
131            },
132            dialect=dialect,
133            copy=copy,
134        )
135
136    if not scope:
137        expression = qualify.qualify(
138            expression,
139            dialect=dialect,
140            schema=schema,
141            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
142        )
143        scope = build_scope(expression)
144
145    if not scope:
146        raise SqlglotError("Cannot build lineage, sql must be SELECT")
147
148    selectable = scope.expression
149    if not isinstance(selectable, exp.Selectable):
150        raise SqlglotError("Cannot build lineage, sql must be a query")
151
152    cache: dict[tuple, Node] = {}
153    scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] = {}
154
155    if column is not None:
156        column_name = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
157        if not any(select.alias_or_name == column_name for select in selectable.selects):
158            raise SqlglotError(f"Cannot find column '{column_name}' in query.")
159
160        return to_node(
161            column_name,
162            scope,
163            dialect,
164            trim_selects=trim_selects,
165            _cache=cache,
166            _scope_meta=scope_meta,
167            on_node=on_node,
168        )
169
170    result: dict[str, Node] = {}
171    for sel in selectable.selects:
172        name = sel.alias_or_name
173        if not name:
174            raise SqlglotError(
175                f"Cannot fetch lineage for unnamed projection: {sel.sql(dialect=dialect)}."
176            )
177
178        result[name] = to_node(
179            name,
180            scope,
181            dialect,
182            trim_selects=trim_selects,
183            _cache=cache,
184            _scope_meta=scope_meta,
185            on_node=on_node,
186        )
187
188    return result
189
190
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: str | None = None,
    upstream: Node | None = None,
    source_name: str | None = None,
    reference_node_name: str | None = None,
    trim_selects: bool = True,
    _cache: dict[tuple, Node] | None = None,
    _scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] | None = None,
    on_node: t.Callable[[Node], None] | None = None,
) -> Node:
    """Build the lineage Node for one column of `scope`, recursing into its sources.

    Args:
        column: The output column to trace, by name or by projection index.
        scope: The scope whose projections contain `column`.
        dialect: The dialect used when rendering SQL for error and log messages.
        scope_name: Optional prefix for the node name (``"<scope_name>.<column>"``).
        upstream: If given, the created (or cached) node is appended to its `downstream`.
        source_name: Stored on the created node as `Node.source_name`.
        reference_node_name: Stored on the created node as `Node.reference_node_name`.
        trim_selects: Whether a Select source is trimmed to just the traced projection.
        _cache: Shared memo of already-built nodes, keyed by column/scope/naming args.
        _scope_meta: Shared per-scope memo of ``(is_star, {name: projection})``.
        on_node: Callback invoked for every Node created, after its downstream is populated.

    Returns:
        The Node for `column`. For a set operation reached with an `upstream`, that
        `upstream` node itself is returned.

    Raises:
        SqlglotError: If an integer `column` is out of range for the scope's projections.
        ValueError: If a named `column` cannot be found in a set operation's projections.
    """
    cache_key = (column, id(scope), scope_name, source_name, reference_node_name)

    if _cache is not None and cache_key in _cache:
        cached_node = _cache[cache_key]
        if upstream:
            upstream.downstream.append(cached_node)
        return cached_node

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    selectable = t.cast(exp.Selectable, scope.expression)
    if isinstance(column, int):
        if column >= len(selectable.selects):
            raise SqlglotError(
                f"Cannot find column's source with index {column} in query: {selectable.sql(dialect=dialect)}"
            )
        select = selectable.selects[column]
    else:
        # Resolving a column to its select scans selectable.selects on every call;
        # memoize a per-scope {name: select} map and is_star bit instead.
        if _scope_meta is None:
            select = next(
                (s for s in selectable.selects if s.alias_or_name == column),
                exp.Star() if selectable.is_star else scope.expression,
            )
        else:
            scope_id = id(scope)
            meta = _scope_meta.get(scope_id)
            if meta is None:
                select_by_name: dict[str, exp.Expr] = {}
                for sel in selectable.selects:
                    # setdefault keeps the first projection for a duplicated name,
                    # matching the linear scan above.
                    select_by_name.setdefault(sel.alias_or_name, sel)
                meta = (selectable.is_star, select_by_name)
                _scope_meta[scope_id] = meta
            is_star, select_by_name = meta
            select = select_by_name.get(column, exp.Star() if is_star else scope.expression)

    if isinstance(scope.expression, exp.Subquery):
        for inner_scope in scope.subquery_scopes:
            result = to_node(
                column,
                scope=inner_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )
            # Skip caching a passed-in upstream returned by an inner SetOp:
            # a sibling call at the same key with that node as its upstream
            # would otherwise self-loop on the cache hit.
            if _cache is not None and result is not upstream:
                _cache[cache_key] = result
            return result
    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        created_setop = upstream is None
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(selectable.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # Every branch of the set operation contributes the projection at `index`.
        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )

        if _cache is not None and created_setop:
            _cache[cache_key] = upstream
        if created_setop and on_node:
            on_node(upstream)
        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source: exp.Expr = scope.expression.select(select, append=False)
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, *exp.UNWRAPPED_QUERIES):
        subquery_scope: Scope | None = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning("Unknown subquery scope: %s", subquery.sql(dialect=dialect))
            continue

        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )

    # if the select is a star add all scope sources as downstreams
    if isinstance(select, exp.Star):
        for src in scope.sources.values():
            src_expr = src.expression if isinstance(src, Scope) else src
            star_node = Node(name=select.sql(comments=False), source=src_expr, expression=src_expr)
            node.downstream.append(star_node)
            if on_node:
                on_node(star_node)

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))
        derived_tables: Sequence[exp.Expr] = [
            src.expression.parent
            for src in scope.sources.values()
            if isinstance(src, Scope) and src.is_derived_table and src.expression.parent
        ]
    else:
        derived_tables = scope.derived_tables

    # Map derived-table alias -> name taken from a leading "source: <name>" comment.
    # A bare "source: " comment carries no name and is ignored rather than raising IndexError.
    source_names: dict[str, str] = {}
    for dt in derived_tables:
        if dt.comments and dt.comments[0].startswith("source: "):
            comment_parts = dt.comments[0].split()
            if len(comment_parts) > 1:
                source_names[dt.alias] = comment_parts[1]

    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        col_source: exp.Table | Scope | None = scope.sources.get(table)

        if isinstance(col_source, Scope):
            # Use a dedicated local rather than rebinding the `reference_node_name`
            # parameter: the pivot branch below reads that parameter, and because
            # `source_columns` is a set, a leaked value would depend on iteration order.
            ref_name = None
            if col_source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
                ref_name = table
            elif col_source.scope_type == ScopeType.CTE:
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                ref_name = selected_node.name if selected_node else None

            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=col_source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=ref_name,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                pivot_parent = pivot.parent
                downstream_columns.append(
                    exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "")
                )

            for downstream_column in downstream_columns:
                table = downstream_column.table
                col_source = scope.sources.get(table)
                if isinstance(col_source, Scope):
                    to_node(
                        downstream_column.name,
                        scope=col_source,
                        scope_name=table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(table) or source_name,
                        # Always the caller-supplied value (never a leftover from another column).
                        reference_node_name=reference_node_name,
                        trim_selects=trim_selects,
                        _cache=_cache,
                        _scope_meta=_scope_meta,
                        on_node=on_node,
                    )
                else:
                    col_expr = col_source or exp.Placeholder()
                    pivot_leaf = Node(
                        name=downstream_column.sql(comments=False),
                        source=col_expr,
                        expression=col_expr,
                    )
                    node.downstream.append(pivot_leaf)
                    if on_node:
                        on_node(pivot_leaf)
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            col_expr = col_source or exp.Placeholder()
            leaf = Node(name=c.sql(comments=False), source=col_expr, expression=col_expr)
            node.downstream.append(leaf)
            if on_node:
                on_node(leaf)

    if _cache is not None:
        _cache[cache_key] = node

    if on_node:
        on_node(node)

    return node
480
481
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    # Labels and titles carry arbitrary SQL text, and the serialized JSON is inlined in a
    # <script> element, where a literal "</script>" (or "<!--") inside a string would end
    # or corrupt the script block. Emitting these characters as JSON \uXXXX escapes keeps
    # the markup inert; JavaScript decodes them back to the original characters.
    _SCRIPT_SAFE_ESCAPES = {ord("<"): "\\u003c", ord(">"): "\\u003e", ord("&"): "\\u0026"}

    def __init__(
        self,
        nodes: dict,
        edges: list,
        imports: bool = True,
        options: Mapping[str, object] | None = None,
    ):
        """
        Args:
            nodes: vis.js node dicts keyed by node id; only the values are serialized.
            edges: vis.js edge dicts, e.g. ``{"from": id, "to": id}``.
            imports: Whether to emit the <script>/<link> tags that load vis.js from unpkg.
            options: vis.js network options; top-level keys replace the defaults below.
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @classmethod
    def _script_json(cls, value: object) -> str:
        """Serialize `value` to JSON that is safe to inline inside a <script> element."""
        return json.dumps(value).translate(cls._SCRIPT_SAFE_ESCAPES)

    def __str__(self):
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        # Each node's "title" is an HTML string; it is parsed into a DOM node in the browser
        # so vis.js renders it as a tooltip element instead of plain text.
        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; returns the same HTML as ``str()``."""
        return self.__str__()
logger = <Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """One step in a column-lineage graph.

    Nodes are built by `lineage()` / `to_node()`. The dataclass is frozen, but
    `downstream` and `payload` are mutable containers that are filled in place.

    Attributes:
        name: Display name of the step, e.g. ``"<scope_name>.<column>"``.
        expression: The expression this step stands for (a projection, table, star, ...).
        source: The expression `expression` was taken from; rendered as context by `to_html`.
        downstream: The nodes this step was derived from, in discovery order.
        source_name: Name parsed from a ``source: <name>`` comment on a derived table
            along the path, or ``""``.
        reference_node_name: Alias of the derived table or name of the CTE this step
            was reached through, or ``""``.
        payload: Caller-injected per-node data, populated via the `on_node` hook on `lineage()`.
    """

    name: str
    expression: exp.Expr
    source: exp.Expr
    downstream: list[Node] = field(default_factory=list)
    source_name: str = ""
    reference_node_name: str = ""

    # Caller-injected per-node data, populated via the `on_node` hook on lineage()
    payload: dict[str, t.Any] = field(default_factory=dict)

    def walk(self) -> Iterator[Node]:
        """Yield this node and every node reachable through `downstream`.

        The traversal is an iterative depth-first pre-order walk that visits children
        in list order. Nodes are de-duplicated by identity, so shared (cached) nodes
        and cycles are yielded only once.
        """
        visited: set[int] = set()
        queue = [self]
        while queue:
            node = queue.pop()
            node_id = id(node)
            if node_id in visited:
                continue
            visited.add(node_id)
            yield node
            # Reversed so that the first child ends up on top of the stack.
            queue.extend(reversed(node.downstream))

    def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
        """Render the lineage graph rooted at this node as a vis.js HTML snippet.

        Args:
            dialect: The dialect used to generate the SQL shown in labels and tooltips.
            **opts: Extra arguments forwarded to `GraphHTML`.

        Returns:
            A `GraphHTML` whose ``str()`` is the HTML snippet.
        """
        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)
                # Bold `node.expression` inside its source, working on a copy. Transforming
                # `node.source` in place (copy=False) would mutate the lineage graph, and
                # nodes that share a source object would accumulate each other's <b> tags
                # (as would repeated to_html() calls).
                highlighted = node.source.copy()
                # A copy has the same shape, so parallel walks pair each original
                # sub-expression with its duplicate.
                target = next(
                    (
                        duplicate
                        for original, duplicate in zip(node.source.walk(), highlighted.walk())
                        if original is node.expression
                    ),
                    None,
                )
                if target is not None:
                    highlighted = highlighted.transform(
                        lambda n: (
                            exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is target else n
                        ),
                        copy=False,
                    )
                source = highlighted.sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in node.downstream:
                edges.append({"from": node_id, "to": id(d)})
        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.core.Expr, source: sqlglot.expressions.core.Expr, downstream: list[Node] = <factory>, source_name: str = '', reference_node_name: str = '', payload: dict[str, typing.Any] = <factory>)
name: str
downstream: list[Node]
source_name: str = ''
reference_node_name: str = ''
payload: dict[str, typing.Any]
def walk(self) -> Iterator[Node]:
def walk(self) -> Iterator[Node]:
    """Iterate over this node and all nodes reachable via ``downstream``.

    Depth-first, pre-order, children taken in list order. A node that can be
    reached along several paths (or through a cycle) is produced only once,
    compared by identity.
    """
    seen: set[int] = set()
    stack: list[Node] = [self]
    while stack:
        current = stack.pop()
        if id(current) in seen:
            continue
        seen.add(id(current))
        yield current
        # Push children last-to-first so the first child is popped next.
        stack.extend(current.downstream[::-1])
def to_html( self, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, **opts: typing_extensions.Unpack[sqlglot._typing.GraphHTMLArgs]) -> GraphHTML:
def to_html(self, dialect: DialectType = None, **opts: Unpack[GraphHTMLArgs]) -> GraphHTML:
    """Render the lineage graph rooted at this node as a vis.js HTML snippet.

    Args:
        dialect: The dialect used to generate the SQL shown in labels and tooltips.
        **opts: Extra arguments forwarded to `GraphHTML`.

    Returns:
        A `GraphHTML` whose ``str()`` is the HTML snippet.
    """
    nodes = {}
    edges = []

    for node in self.walk():
        if isinstance(node.expression, exp.Table):
            label = f"FROM {node.expression.this}"
            title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
            group = 1
        else:
            label = node.expression.sql(pretty=True, dialect=dialect)
            # Bold `node.expression` inside its source, working on a copy. Transforming
            # `node.source` in place (copy=False) would mutate the lineage graph, and
            # nodes that share a source object would accumulate each other's <b> tags
            # (as would repeated to_html() calls).
            highlighted = node.source.copy()
            # A copy has the same shape, so parallel walks pair each original
            # sub-expression with its duplicate.
            target = next(
                (
                    duplicate
                    for original, duplicate in zip(node.source.walk(), highlighted.walk())
                    if original is node.expression
                ),
                None,
            )
            if target is not None:
                highlighted = highlighted.transform(
                    lambda n: (
                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is target else n
                    ),
                    copy=False,
                )
            source = highlighted.sql(pretty=True, dialect=dialect)
            title = f"<pre>{source}</pre>"
            group = 0

        node_id = id(node)

        nodes[node_id] = {
            "id": node_id,
            "label": label,
            "title": title,
            "group": group,
        }

        for d in node.downstream:
            edges.append({"from": node_id, "to": id(d)})
    return GraphHTML(nodes, edges, **opts)
def lineage( column: str | sqlglot.expressions.core.Column | None, sql: str | sqlglot.expressions.core.Expr, schema: dict | sqlglot.schema.Schema | None = None, sources: Mapping[str, str | sqlglot.expressions.query.Query] | None = None, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, scope: sqlglot.optimizer.scope.Scope | None = None, trim_selects: bool = True, copy: bool = True, on_node: Optional[Callable[[Node], NoneType]] = None, **kwargs) -> Node | dict[str, Node]:
 89def lineage(
 90    column: str | exp.Column | None,
 91    sql: str | exp.Expr,
 92    schema: dict | Schema | None = None,
 93    sources: Mapping[str, str | exp.Query] | None = None,
 94    dialect: DialectType = None,
 95    scope: Scope | None = None,
 96    trim_selects: bool = True,
 97    copy: bool = True,
 98    on_node: t.Callable[[Node], None] | None = None,
 99    **kwargs,
100) -> Node | dict[str, Node]:
101    """Build the lineage graph for a SQL query.
102
103    If `column` is given, returns the lineage Node for that single output column.
104    If `column` is None, returns a dict mapping every top-level output column name
105    to its lineage Node (with a shared cache so cross-column work is deduplicated).
106
107    Args:
108        column: The column to build the lineage for. Pass None to get all output columns.
109        sql: The SQL string or expression.
110        schema: The schema of tables.
111        sources: A mapping of queries which will be used to continue building lineage.
112        dialect: The dialect of input SQL.
113        scope: A pre-created scope to use instead.
114        trim_selects: Whether to clean up selects by trimming to only relevant columns.
115        copy: Whether to copy the Expr arguments.
116        on_node: Optional callback invoked for every Node created during the walk,
117            after the Node's downstream is populated. Useful for injecting
118            caller-managed data into Node.payload during the walk.
119        **kwargs: Qualification optimizer kwargs.
120
121    Returns:
122        A Node when `column` is provided, or a dict[str, Node] when `column` is None.
123    """
124    expression = maybe_parse(sql, copy=copy, dialect=dialect)
125
126    if sources:
127        expression = exp.expand(
128            expression,
129            {
130                k: t.cast(exp.Query, maybe_parse(v, copy=copy, dialect=dialect))
131                for k, v in sources.items()
132            },
133            dialect=dialect,
134            copy=copy,
135        )
136
137    if not scope:
138        expression = qualify.qualify(
139            expression,
140            dialect=dialect,
141            schema=schema,
142            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
143        )
144        scope = build_scope(expression)
145
146    if not scope:
147        raise SqlglotError("Cannot build lineage, sql must be SELECT")
148
149    selectable = scope.expression
150    if not isinstance(selectable, exp.Selectable):
151        raise SqlglotError("Cannot build lineage, sql must be a query")
152
153    cache: dict[tuple, Node] = {}
154    scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] = {}
155
156    if column is not None:
157        column_name = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
158        if not any(select.alias_or_name == column_name for select in selectable.selects):
159            raise SqlglotError(f"Cannot find column '{column_name}' in query.")
160
161        return to_node(
162            column_name,
163            scope,
164            dialect,
165            trim_selects=trim_selects,
166            _cache=cache,
167            _scope_meta=scope_meta,
168            on_node=on_node,
169        )
170
171    result: dict[str, Node] = {}
172    for sel in selectable.selects:
173        name = sel.alias_or_name
174        if not name:
175            raise SqlglotError(
176                f"Cannot fetch lineage for unnamed projection: {sel.sql(dialect=dialect)}."
177            )
178
179        result[name] = to_node(
180            name,
181            scope,
182            dialect,
183            trim_selects=trim_selects,
184            _cache=cache,
185            _scope_meta=scope_meta,
186            on_node=on_node,
187        )
188
189    return result

Build the lineage graph for a SQL query.

If column is given, returns the lineage Node for that single output column. If column is None, returns a dict mapping every top-level output column name to its lineage Node (with a shared cache so cross-column work is deduplicated).

Arguments:
  • column: The column to build the lineage for. Pass None to get all output columns.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • dialect: The dialect of input SQL.
  • scope: A pre-created scope to use instead.
  • trim_selects: Whether to clean up selects by trimming to only relevant columns.
  • copy: Whether to copy the Expr arguments.
  • on_node: Optional callback invoked for every Node created during the walk, after the Node's downstream is populated. Useful for injecting caller-managed data into Node.payload during the walk.
  • **kwargs: Qualification optimizer kwargs.
Returns:

A Node when column is provided, or a dict[str, Node] when column is None.

def to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType], scope_name: str | None = None, upstream: Node | None = None, source_name: str | None = None, reference_node_name: str | None = None, trim_selects: bool = True, _cache: dict[tuple, Node] | None = None, _scope_meta: dict[int, tuple[bool, dict[str, sqlglot.expressions.core.Expr]]] | None = None, on_node: Optional[Callable[[Node], NoneType]] = None) -> Node:
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: str | None = None,
    upstream: Node | None = None,
    source_name: str | None = None,
    reference_node_name: str | None = None,
    trim_selects: bool = True,
    _cache: dict[tuple, Node] | None = None,
    _scope_meta: dict[int, tuple[bool, dict[str, exp.Expr]]] | None = None,
    on_node: t.Callable[[Node], None] | None = None,
) -> Node:
    """Build the lineage node for ``column`` in ``scope`` and recurse into its sources.

    Args:
        column: Output column name, or its positional index in the scope's selects
            (indices are used when walking the branches of a set operation).
        scope: The scope whose projection produces ``column``.
        dialect: Dialect used when rendering SQL for error/log messages; passed to recursive calls.
        scope_name: Optional prefix for the node name (``f"{scope_name}.{column}"``).
        upstream: If given, the created (or cached) node is appended to ``upstream.downstream``.
        source_name: Stored on the created node. When recursing into a source scope, a
            ``source: <name>`` comment on the matching derived table takes precedence.
        reference_node_name: Stored on the created node's ``reference_node_name``.
        trim_selects: If True and the scope is a SELECT, the node's ``source`` is the query
            rewritten to project only the selected expression.
        _cache: Optional memo keyed by (column, id(scope), scope_name, source_name,
            reference_node_name) so repeated calls share nodes.
        _scope_meta: Optional per-scope memo of ``(is_star, {name: select})`` used to resolve
            column names without rescanning the selects.
        on_node: Optional callback invoked once for every node created by this call tree,
            after that node's downstream list has been populated.

    Returns:
        The lineage node for ``column``.

    Raises:
        SqlglotError: If an integer ``column`` is past the end of the scope's selects.
        ValueError: If a set operation has no projection matching ``column``.
    """
    cache_key = (column, id(scope), scope_name, source_name, reference_node_name)

    if _cache is not None and cache_key in _cache:
        cached_node = _cache[cache_key]
        if upstream:
            upstream.downstream.append(cached_node)
        return cached_node

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    selectable = t.cast(exp.Selectable, scope.expression)
    if isinstance(column, int):
        if column >= len(selectable.selects):
            raise SqlglotError(
                f"Cannot find column's source with index {column} in query: {selectable.sql(dialect=dialect)}"
            )
        select = selectable.selects[column]
    else:
        # Resolving a column to its select scans selectable.selects on every call;
        # memoize a per-scope {name: select} map and is_star bit instead.
        if _scope_meta is None:
            select = next(
                (s for s in selectable.selects if s.alias_or_name == column),
                exp.Star() if selectable.is_star else scope.expression,
            )
        else:
            scope_id = id(scope)
            meta = _scope_meta.get(scope_id)
            if meta is None:
                select_by_name: dict[str, exp.Expr] = {}
                for sel in selectable.selects:
                    # setdefault: the first projection with a given name wins, like next() above.
                    select_by_name.setdefault(sel.alias_or_name, sel)
                meta = (selectable.is_star, select_by_name)
                _scope_meta[scope_id] = meta
            is_star, select_by_name = meta
            select = select_by_name.get(column, exp.Star() if is_star else scope.expression)

    if isinstance(scope.expression, exp.Subquery):
        # Delegate to the first inner scope of the subquery (the loop returns immediately).
        for inner_scope in scope.subquery_scopes:
            result = to_node(
                column,
                scope=inner_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )
            # Skip caching a passed-in upstream returned by an inner SetOp:
            # a sibling call at the same key with that node as its upstream
            # would otherwise self-loop on the cache hit.
            if _cache is not None and result is not upstream:
                _cache[cache_key] = result
            return result
    if isinstance(scope.expression, exp.SetOperation):
        name = type(scope.expression).__name__.upper()
        created_setop = upstream is None
        upstream = upstream or Node(name=name, source=scope.expression, expression=select)

        # Branches of a set operation are matched by position, not by name.
        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(selectable.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )

        # Only cache / report the SetOp node if this call created it.
        if _cache is not None and created_setop:
            _cache[cache_key] = upstream
        if created_setop and on_node:
            on_node(upstream)
        return upstream

    if trim_selects and isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source: exp.Expr = scope.expression.select(select, append=False)
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, *exp.UNWRAPPED_QUERIES):
        subquery_scope: Scope | None = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
            continue

        for name in subquery.named_selects:
            to_node(
                name,
                scope=subquery_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )

    # if the select is a star add all scope sources as downstreams
    if isinstance(select, exp.Star):
        for src in scope.sources.values():
            src_expr = src.expression if isinstance(src, Scope) else src
            star_node = Node(name=select.sql(comments=False), source=src_expr, expression=src_expr)
            node.downstream.append(star_node)
            if on_node:
                on_node(star_node)

    # Find all columns that went into creating this one to list their lineage nodes.
    # A dict is used as an insertion-ordered set: structurally equal columns are still
    # deduplicated, but downstream nodes are attached in traversal order instead of
    # (per-process randomized) hash order.
    source_columns = dict.fromkeys(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table
    if isinstance(source, exp.UDTF):
        source_columns.update(dict.fromkeys(source.find_all(exp.Column)))
        derived_tables: Sequence[exp.Expr] = [
            src.expression.parent
            for src in scope.sources.values()
            if isinstance(src, Scope) and src.is_derived_table and src.expression.parent
        ]
    else:
        derived_tables = scope.derived_tables

    # Map derived-table alias -> name taken from a leading "source: <name>" comment.
    source_names: dict[str, str] = {}
    for dt in derived_tables:
        if dt.comments and dt.comments[0].startswith("source: "):
            parts = dt.comments[0].split()
            # A bare "source: " comment names nothing; skip it instead of raising IndexError.
            if len(parts) > 1:
                source_names[dt.alias] = parts[1]

    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)

        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols

    for c in source_columns:
        table = c.table
        col_source: exp.Table | Scope | None = scope.sources.get(table)

        if isinstance(col_source, Scope):
            # Use a per-column local rather than reassigning the ``reference_node_name``
            # parameter, which would otherwise leak into later iterations (e.g. the pivot
            # branch below).
            col_reference_name: str | None = None
            if col_source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
                col_reference_name = table
            elif col_source.scope_type == ScopeType.CTE:
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                col_reference_name = selected_node.name if selected_node else None

            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=col_source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=col_reference_name,
                trim_selects=trim_selects,
                _cache=_cache,
                _scope_meta=_scope_meta,
                on_node=on_node,
            )
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                pivot_parent = pivot.parent
                downstream_columns.append(
                    exp.column(c.this, table=pivot_parent.alias_or_name if pivot_parent else "")
                )

            for downstream_column in downstream_columns:
                table = downstream_column.table
                col_source = scope.sources.get(table)
                if isinstance(col_source, Scope):
                    to_node(
                        downstream_column.name,
                        scope=col_source,
                        scope_name=table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(table) or source_name,
                        reference_node_name=reference_node_name,
                        trim_selects=trim_selects,
                        _cache=_cache,
                        _scope_meta=_scope_meta,
                        on_node=on_node,
                    )
                else:
                    col_expr = col_source or exp.Placeholder()
                    pivot_leaf = Node(
                        name=downstream_column.sql(comments=False),
                        source=col_expr,
                        expression=col_expr,
                    )
                    node.downstream.append(pivot_leaf)
                    if on_node:
                        on_node(pivot_leaf)
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            col_expr = col_source or exp.Placeholder()
            leaf = Node(name=c.sql(comments=False), source=col_expr, expression=col_expr)
            node.downstream.append(leaf)
            if on_node:
                on_node(leaf)

    if _cache is not None:
        _cache[cache_key] = node

    if on_node:
        on_node(node)

    return node
class GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/

    The rendered HTML inlines the nodes, edges and options as JSON inside a ``<script>``
    element. The JSON is escaped so that text taken from SQL (for example a string literal
    containing ``</script>``) cannot terminate that element early.
    """

    def __init__(
        self,
        nodes: dict,
        edges: list,
        imports: bool = True,
        options: Mapping[str, object] | None = None,
    ):
        """Store the graph data and build the vis.js options.

        Args:
            nodes: Mapping whose values are the vis.js node dicts to render.
            edges: vis.js edge dicts.
            imports: Whether the rendered HTML includes the vis.js script/style tags.
            options: Extra vis.js options; each top-level key replaces the default entry
                (shallow merge, no deep merge of nested dicts).
        """
        self.imports = imports

        # Default vis.js network options; caller-supplied top-level keys override these.
        merged: dict[str, object] = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
        }
        merged.update(options or {})
        self.options = merged

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _dump_json(value: object) -> str:
        """Serialize ``value`` to JSON that is safe to inline inside a ``<script>`` element.

        ``json.dumps`` leaves ``<``, ``>`` and ``&`` untouched, so a label such as
        ``"</script>"`` would end the script element. In JSON output those characters can
        only appear inside string literals, where the ``\\uXXXX`` escapes decode back to the
        same characters in JavaScript, so the data seen by the page is unchanged.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self) -> str:
        nodes = self._dump_json(list(self.nodes.values()))
        edges = self._dump_json(self.edges)
        options = self._dump_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

GraphHTML( nodes: dict, edges: list, imports: bool = True, options: Mapping[str, object] | None = None)
489    def __init__(
490        self,
491        nodes: dict,
492        edges: list,
493        imports: bool = True,
494        options: Mapping[str, object] | None = None,
495    ):
496        self.imports = imports
497
498        self.options = {
499            "height": "500px",
500            "width": "100%",
501            "layout": {
502                "hierarchical": {
503                    "enabled": True,
504                    "nodeSpacing": 200,
505                    "sortMethod": "directed",
506                },
507            },
508            "interaction": {
509                "dragNodes": False,
510                "selectable": False,
511            },
512            "physics": {
513                "enabled": False,
514            },
515            "edges": {
516                "arrows": "to",
517            },
518            "nodes": {
519                "font": "20px monaco",
520                "shape": "box",
521                "widthConstraint": {
522                    "maximum": 300,
523                },
524            },
525            **(options or {}),
526        }
527
528        self.nodes = nodes
529        self.edges = edges
imports
options
nodes
edges