  1from __future__ import annotations
  3import json
  4import logging
  5import typing as t
  6from dataclasses import dataclass, field
  8from sqlglot import Schema, exp, maybe_parse
  9from sqlglot.errors import SqlglotError
 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
 11from sqlglot.optimizer.scope import ScopeType
 14    from sqlglot.dialects.dialect import DialectType
 16logger = logging.getLogger("sqlglot")
 20class Node:
 21    name: str
 22    expression: exp.Expression
 23    source: exp.Expression
 24    downstream: t.List[Node] = field(default_factory=list)
 25    source_name: str = ""
 26    reference_node_name: str = ""
 28    def walk(self) -> t.Iterator[Node]:
 29        yield self
 31        for d in self.downstream:
 32            yield from d.walk()
 34    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
 35        nodes = {}
 36        edges = []
 38        for node in self.walk():
 39            if isinstance(node.expression, exp.Table):
 40                label = f"FROM {node.expression.this}"
 41                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
 42                group = 1
 43            else:
 44                label = node.expression.sql(pretty=True, dialect=dialect)
 45                source = node.source.transform(
 46                    lambda n: (
 47                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
 48                    ),
 49                    copy=False,
 50                ).sql(pretty=True, dialect=dialect)
 51                title = f"<pre>{source}</pre>"
 52                group = 0
 54            node_id = id(node)
 56            nodes[node_id] = {
 57                "id": node_id,
 58                "label": label,
 59                "title": title,
 60                "group": group,
 61            }
 63            for d in node.downstream:
 64                edges.append({"from": node_id, "to": id(d)})
 65        return GraphHTML(nodes, edges, **opts)
 68def lineage(
 69    column: str | exp.Column,
 70    sql: str | exp.Expression,
 71    schema: t.Optional[t.Dict | Schema] = None,
 72    sources: t.Optional[t.Mapping[str, str | exp.Query]] = None,
 73    dialect: DialectType = None,
 74    scope: t.Optional[Scope] = None,
 75    trim_selects: bool = True,
 76    **kwargs,
 77) -> Node:
 78    """Build the lineage graph for a column of a SQL query.
 80    Args:
 81        column: The column to build the lineage for.
 82        sql: The SQL string or expression.
 83        schema: The schema of tables.
 84        sources: A mapping of queries which will be used to continue building lineage.
 85        dialect: The dialect of input SQL.
 86        scope: A pre-created scope to use instead.
 87        trim_selects: Whether or not to clean up selects by trimming to only relevant columns.
 88        **kwargs: Qualification optimizer kwargs.
 90    Returns:
 91        A lineage node.
 92    """
 94    expression = maybe_parse(sql, dialect=dialect)
 95    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
 97    if sources:
 98        expression = exp.expand(
 99            expression,
100            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
101            dialect=dialect,
102        )
104    if not scope:
105        expression = qualify.qualify(
106            expression,
107            dialect=dialect,
108            schema=schema,
109            **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
110        )
112        scope = build_scope(expression)
114    if not scope:
115        raise SqlglotError("Cannot build lineage, sql must be SELECT")
117    if not any(select.alias_or_name == column for select in scope.expression.selects):
118        raise SqlglotError(f"Cannot find column '{column}' in query.")
120    return to_node(column, scope, dialect, trim_selects=trim_selects)
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
    trim_selects: bool = True,
) -> Node:
    """Build the lineage node for `column` in `scope` and recurse into its inputs.

    Args:
        column: Name of the projection, or its positional index.
        scope: The scope whose expression produces the column.
        dialect: Dialect used when rendering SQL (only for the warning message here).
        scope_name: When set, the node is named ``"<scope_name>.<column>"``.
        upstream: Node that the new node is appended to as a downstream child.
        source_name: Stored on the node and passed down to child calls.
        reference_node_name: Stored on the node.
        trim_selects: Whether a SELECT source is reduced to the one relevant projection.

    Returns:
        The node created for this column. For a subquery scope this is the result
        of recursing into its first inner scope; for a set operation it is the
        set-operation node (or the given `upstream`).

    Raises:
        ValueError: In the set-operation branch, if no matching projection is found.
    """
    query = scope.expression

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    if isinstance(column, int):
        select = query.selects[column]
    else:
        for projection in query.selects:
            if projection.alias_or_name == column:
                select = projection
                break
        else:
            select = exp.Star() if query.is_star else query

    if isinstance(query, exp.Subquery):
        # Only the first inner scope is followed; with none, fall through.
        first_inner = next(iter(scope.subquery_scopes), None)
        if first_inner is not None:
            return to_node(
                column,
                scope=first_inner,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )

    if isinstance(query, exp.SetOperation):
        if not upstream:
            upstream = Node(name=type(query).__name__.upper(), source=query, expression=select)

        # Branches of a set operation are matched by position, not by name.
        if isinstance(column, int):
            position = column
        else:
            position = -1  # sentinel: no projection matched
            for i, projection in enumerate(query.selects):
                if projection.alias_or_name == column or projection.is_star:
                    position = i
                    break

        if position == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        for branch_scope in scope.union_scopes:
            to_node(
                position,
                scope=branch_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )

        return upstream

    if trim_selects and isinstance(query, exp.Select):
        # For nicer node labels, keep only the projection we care about:
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        origin = t.cast(exp.Expression, query.select(select, append=False))
    else:
        origin = query

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=origin,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    # Subqueries inside the projection are matched to their scopes by object identity.
    scopes_by_expression_id = {}
    for inner_scope in scope.subquery_scopes:
        scopes_by_expression_id[id(inner_scope.expression)] = inner_scope

    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        matched_scope = scopes_by_expression_id.get(id(subquery))
        if not matched_scope:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
            continue

        for projection_name in subquery.named_selects:
            to_node(
                projection_name,
                scope=matched_scope,
                dialect=dialect,
                upstream=node,
                trim_selects=trim_selects,
            )

    # The UDTF check further down inspects `udtf_probe`. It starts as this node's
    # source, but when the star branch runs it ends up as the last star source
    # visited. That mirrors the variable reuse in the code this was derived from.
    udtf_probe = origin

    # If the select is a star, add all scope sources as downstreams.
    if select.is_star:
        for star_source in scope.sources.values():
            star_expression = (
                star_source.expression if isinstance(star_source, Scope) else star_source
            )
            node.downstream.append(
                Node(
                    name=select.sql(comments=False),
                    source=star_expression,
                    expression=star_expression,
                )
            )
            udtf_probe = star_expression

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF, also use the columns that feed the UDTF itself.
    if isinstance(udtf_probe, exp.UDTF):
        source_columns |= set(udtf_probe.find_all(exp.Column))
        derived_tables = [
            candidate.expression.parent
            for candidate in scope.sources.values()
            if isinstance(candidate, Scope) and candidate.is_derived_table
        ]
    else:
        derived_tables = scope.derived_tables

    # A derived table whose first comment starts with "source: " carries a source
    # name: the second whitespace-separated token of that comment.
    source_names = {}
    for derived in derived_tables:
        comments = derived.comments
        if comments and comments[0].startswith("source: "):
            source_names[derived.alias] = comments[0].split()[1]

    for c in source_columns:
        table = c.table
        table_source = scope.sources.get(table)

        if not isinstance(table_source, Scope):
            # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
            # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
            # is not passed into the `sources` map.
            leaf = table_source or exp.Placeholder()
            node.downstream.append(Node(name=c.sql(comments=False), source=leaf, expression=leaf))
            continue

        child_reference_name = None
        if table_source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
            child_reference_name = table
        elif table_source.scope_type == ScopeType.CTE:
            selected_node, _ = scope.selected_sources.get(table, (None, None))
            child_reference_name = selected_node.name if selected_node else None

        # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
        to_node(
            c.name,
            scope=table_source,
            dialect=dialect,
            scope_name=table,
            upstream=node,
            source_name=source_names.get(table) or source_name,
            reference_node_name=child_reference_name,
            trim_selects=trim_selects,
        )

    return node
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and the vis.js options.

        Args:
            nodes: Mapping of node id to the vis.js node dict; only the values are rendered.
            edges: List of vis.js edge dicts.
            imports: Whether the rendered HTML includes the vis.js script and stylesheet tags.
            options: vis.js options merged over the defaults below. The merge is
                shallow: a top-level key given here replaces the default entry whole.
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_safe_json(value: t.Any) -> str:
        """Serialize `value` to JSON that is safe inside an inline ``<script>``.

        A literal ``</script>`` inside a JSON string would end the script element
        early. ``\\/`` is a valid JSON and JavaScript escape for ``/``, so rewriting
        ``</`` to ``<\\/`` leaves the decoded data unchanged.
        """
        return json.dumps(value).replace("</", "<\\/")

    def __str__(self):
        """Return the complete HTML snippet, including the closing ``</div>``."""
        nodes = self._script_safe_json(list(self.nodes.values()))
        edges = self._script_safe_json(self.edges)
        options = self._script_safe_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        # The titles hold HTML markup, so each one is parsed into a DOM element
        # in the browser before being handed to vis.js.
        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])
    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as ``str(self)``."""
        return self.__str__()
