Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import typing as t
  5from dataclasses import dataclass, field
  6
  7from sqlglot import Schema, exp, maybe_parse
  8from sqlglot.errors import SqlglotError
  9from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify
 10
 11if t.TYPE_CHECKING:
 12    from sqlglot.dialects.dialect import DialectType
 13
 14
 15@dataclass(frozen=True)
 16class Node:
 17    name: str
 18    expression: exp.Expression
 19    source: exp.Expression
 20    downstream: t.List[Node] = field(default_factory=list)
 21    alias: str = ""
 22
 23    def walk(self) -> t.Iterator[Node]:
 24        yield self
 25
 26        for d in self.downstream:
 27            if isinstance(d, Node):
 28                yield from d.walk()
 29            else:
 30                yield d
 31
 32    def to_html(self, **opts) -> LineageHTML:
 33        return LineageHTML(self, **opts)
 34
 35
 36def lineage(
 37    column: str | exp.Column,
 38    sql: str | exp.Expression,
 39    schema: t.Optional[t.Dict | Schema] = None,
 40    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 41    dialect: DialectType = None,
 42    **kwargs,
 43) -> Node:
 44    """Build the lineage graph for a column of a SQL query.
 45
 46    Args:
 47        column: The column to build the lineage for.
 48        sql: The SQL string or expression.
 49        schema: The schema of tables.
 50        sources: A mapping of queries which will be used to continue building lineage.
 51        dialect: The dialect of input SQL.
 52        **kwargs: Qualification optimizer kwargs.
 53
 54    Returns:
 55        A lineage node.
 56    """
 57
 58    expression = maybe_parse(sql, dialect=dialect)
 59
 60    if sources:
 61        expression = exp.expand(
 62            expression,
 63            {
 64                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 65                for k, v in sources.items()
 66            },
 67            dialect=dialect,
 68        )
 69
 70    qualified = qualify.qualify(
 71        expression,
 72        dialect=dialect,
 73        schema=schema,
 74        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
 75    )
 76
 77    scope = build_scope(qualified)
 78
 79    if not scope:
 80        raise SqlglotError("Cannot build lineage, sql must be SELECT")
 81
 82    def to_node(
 83        column: str | int,
 84        scope: Scope,
 85        scope_name: t.Optional[str] = None,
 86        upstream: t.Optional[Node] = None,
 87        alias: t.Optional[str] = None,
 88    ) -> Node:
 89        aliases = {
 90            dt.alias: dt.comments[0].split()[1]
 91            for dt in scope.derived_tables
 92            if dt.comments and dt.comments[0].startswith("source: ")
 93        }
 94
 95        # Find the specific select clause that is the source of the column we want.
 96        # This can either be a specific, named select or a generic `*` clause.
 97        select = (
 98            scope.expression.selects[column]
 99            if isinstance(column, int)
100            else next(
101                (select for select in scope.expression.selects if select.alias_or_name == column),
102                exp.Star() if scope.expression.is_star else None,
103            )
104        )
105
106        if not select:
107            raise ValueError(f"Could not find {column} in {scope.expression}")
108
109        if isinstance(scope.expression, exp.Union):
110            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
111
112            index = (
113                column
114                if isinstance(column, int)
115                else next(
116                    (
117                        i
118                        for i, select in enumerate(scope.expression.selects)
119                        if select.alias_or_name == column or select.is_star
120                    ),
121                    -1,  # mypy will not allow a None here, but a negative index should never be returned
122                )
123            )
124
125            if index == -1:
126                raise ValueError(f"Could not find {column} in {scope.expression}")
127
128            for s in scope.union_scopes:
129                to_node(index, scope=s, upstream=upstream)
130
131            return upstream
132
133        if isinstance(scope.expression, exp.Select):
134            # For better ergonomics in our node labels, replace the full select with
135            # a version that has only the column we care about.
136            #   "x", SELECT x, y FROM foo
137            #     => "x", SELECT x FROM foo
138            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
139        else:
140            source = scope.expression
141
142        # Create the node for this step in the lineage chain, and attach it to the previous one.
143        node = Node(
144            name=f"{scope_name}.{column}" if scope_name else str(column),
145            source=source,
146            expression=select,
147            alias=alias or "",
148        )
149
150        if upstream:
151            upstream.downstream.append(node)
152
153        subquery_scopes = {
154            id(subquery_scope.expression): subquery_scope
155            for subquery_scope in scope.subquery_scopes
156        }
157
158        for subquery in find_all_in_scope(select, exp.Subqueryable):
159            subquery_scope = subquery_scopes[id(subquery)]
160
161            for name in subquery.named_selects:
162                to_node(name, scope=subquery_scope, upstream=node)
163
164        # if the select is a star add all scope sources as downstreams
165        if select.is_star:
166            for source in scope.sources.values():
167                node.downstream.append(Node(name=select.sql(), source=source, expression=source))
168
169        # Find all columns that went into creating this one to list their lineage nodes.
170        source_columns = set(find_all_in_scope(select, exp.Column))
171
172        # If the source is a UDTF find columns used in the UTDF to generate the table
173        if isinstance(source, exp.UDTF):
174            source_columns |= set(source.find_all(exp.Column))
175
176        for c in source_columns:
177            table = c.table
178            source = scope.sources.get(table)
179
180            if isinstance(source, Scope):
181                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
182                to_node(
183                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
184                )
185            else:
186                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
187                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
188                # is not passed into the `sources` map.
189                source = source or exp.Placeholder()
190                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
191
192        return node
193
194    return to_node(column if isinstance(column, str) else column.name, scope)
195
196
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for the graph rooted at `node`.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect used when generating SQL for labels and tooltips.
            imports: Whether the generated HTML loads the vis.js scripts and stylesheet.
            **opts: Extra top-level vis.js options; they override the defaults below.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes = {}
        self.edges = []

        for current in node.walk():
            if isinstance(current.expression, exp.Table):
                label = f"FROM {current.expression.this}"
                title = f"<pre>SELECT {current.name} FROM {current.expression.this}</pre>"
                group = 1
            else:
                label = current.expression.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b> tags inside its source. copy=False
                # keeps object identity so the `is` check can match; the transform therefore
                # works on `current.source` itself rather than on a copy.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is current.expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            # Object identity doubles as the vis.js node id, for nodes and edge endpoints alike.
            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that is safe to embed inside a <script> element."""
        # A "</" inside a string (e.g. SQL containing "</script>") would end the script
        # element early. "<\/" decodes to the same string in JSON/JavaScript but is not
        # seen as a closing tag by the HTML parser.
        return json.dumps(value).replace("</", "<\\/")

    def __str__(self):
        """Return the HTML snippet: container div, optional imports and the vis.js script."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`, for rich display front ends."""
        return self.__str__()
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """A single step in a column's lineage graph.

    Attributes:
        name: Label of the node.
        expression: The expression this node stands for.
        source: The expression that `expression` was taken from.
        downstream: Nodes attached below this one; filled in while the graph is built.
        alias: Optional alias recorded for the node, empty when there is none.
    """

    name: str
    expression: exp.Expression
    source: exp.Expression
    downstream: t.List[Node] = field(default_factory=list)
    alias: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every node reachable through `downstream`, depth first."""
        yield self

        # `downstream` is declared as a list of Node, so each child can be walked directly.
        for d in self.downstream:
            yield from d.walk()

    def to_html(self, **opts) -> LineageHTML:
        """Return a `LineageHTML` renderer for the graph rooted at this node.

        Args:
            **opts: Forwarded unchanged to `LineageHTML`.
        """
        return LineageHTML(self, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, alias: str = '')
name: str
downstream: List[Node]
alias: str = ''
def walk(self) -> Iterator[Node]:
24    def walk(self) -> t.Iterator[Node]:
25        yield self
26
27        for d in self.downstream:
28            if isinstance(d, Node):
29                yield from d.walk()
30            else:
31                yield d
def to_html(self, **opts) -> LineageHTML:
33    def to_html(self, **opts) -> LineageHTML:
34        return LineageHTML(self, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Subqueryable]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
 37def lineage(
 38    column: str | exp.Column,
 39    sql: str | exp.Expression,
 40    schema: t.Optional[t.Dict | Schema] = None,
 41    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 42    dialect: DialectType = None,
 43    **kwargs,
 44) -> Node:
 45    """Build the lineage graph for a column of a SQL query.
 46
 47    Args:
 48        column: The column to build the lineage for.
 49        sql: The SQL string or expression.
 50        schema: The schema of tables.
 51        sources: A mapping of queries which will be used to continue building lineage.
 52        dialect: The dialect of input SQL.
 53        **kwargs: Qualification optimizer kwargs.
 54
 55    Returns:
 56        A lineage node.
 57    """
 58
 59    expression = maybe_parse(sql, dialect=dialect)
 60
 61    if sources:
 62        expression = exp.expand(
 63            expression,
 64            {
 65                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 66                for k, v in sources.items()
 67            },
 68            dialect=dialect,
 69        )
 70
 71    qualified = qualify.qualify(
 72        expression,
 73        dialect=dialect,
 74        schema=schema,
 75        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
 76    )
 77
 78    scope = build_scope(qualified)
 79
 80    if not scope:
 81        raise SqlglotError("Cannot build lineage, sql must be SELECT")
 82
 83    def to_node(
 84        column: str | int,
 85        scope: Scope,
 86        scope_name: t.Optional[str] = None,
 87        upstream: t.Optional[Node] = None,
 88        alias: t.Optional[str] = None,
 89    ) -> Node:
 90        aliases = {
 91            dt.alias: dt.comments[0].split()[1]
 92            for dt in scope.derived_tables
 93            if dt.comments and dt.comments[0].startswith("source: ")
 94        }
 95
 96        # Find the specific select clause that is the source of the column we want.
 97        # This can either be a specific, named select or a generic `*` clause.
 98        select = (
 99            scope.expression.selects[column]
100            if isinstance(column, int)
101            else next(
102                (select for select in scope.expression.selects if select.alias_or_name == column),
103                exp.Star() if scope.expression.is_star else None,
104            )
105        )
106
107        if not select:
108            raise ValueError(f"Could not find {column} in {scope.expression}")
109
110        if isinstance(scope.expression, exp.Union):
111            upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
112
113            index = (
114                column
115                if isinstance(column, int)
116                else next(
117                    (
118                        i
119                        for i, select in enumerate(scope.expression.selects)
120                        if select.alias_or_name == column or select.is_star
121                    ),
122                    -1,  # mypy will not allow a None here, but a negative index should never be returned
123                )
124            )
125
126            if index == -1:
127                raise ValueError(f"Could not find {column} in {scope.expression}")
128
129            for s in scope.union_scopes:
130                to_node(index, scope=s, upstream=upstream)
131
132            return upstream
133
134        if isinstance(scope.expression, exp.Select):
135            # For better ergonomics in our node labels, replace the full select with
136            # a version that has only the column we care about.
137            #   "x", SELECT x, y FROM foo
138            #     => "x", SELECT x FROM foo
139            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
140        else:
141            source = scope.expression
142
143        # Create the node for this step in the lineage chain, and attach it to the previous one.
144        node = Node(
145            name=f"{scope_name}.{column}" if scope_name else str(column),
146            source=source,
147            expression=select,
148            alias=alias or "",
149        )
150
151        if upstream:
152            upstream.downstream.append(node)
153
154        subquery_scopes = {
155            id(subquery_scope.expression): subquery_scope
156            for subquery_scope in scope.subquery_scopes
157        }
158
159        for subquery in find_all_in_scope(select, exp.Subqueryable):
160            subquery_scope = subquery_scopes[id(subquery)]
161
162            for name in subquery.named_selects:
163                to_node(name, scope=subquery_scope, upstream=node)
164
165        # if the select is a star add all scope sources as downstreams
166        if select.is_star:
167            for source in scope.sources.values():
168                node.downstream.append(Node(name=select.sql(), source=source, expression=source))
169
170        # Find all columns that went into creating this one to list their lineage nodes.
171        source_columns = set(find_all_in_scope(select, exp.Column))
172
173        # If the source is a UDTF find columns used in the UTDF to generate the table
174        if isinstance(source, exp.UDTF):
175            source_columns |= set(source.find_all(exp.Column))
176
177        for c in source_columns:
178            table = c.table
179            source = scope.sources.get(table)
180
181            if isinstance(source, Scope):
182                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
183                to_node(
184                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
185                )
186            else:
187                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
188                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
189                # is not passed into the `sources` map.
190                source = source or exp.Placeholder()
191                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
192
193        return node
194
195    return to_node(column if isinstance(column, str) else column.name, scope)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • dialect: The dialect of input SQL.
  • **kwargs: Qualification optimizer kwargs.
Returns:

A lineage node.

class LineageHTML:
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for the graph rooted at `node`.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect used when generating SQL for labels and tooltips.
            imports: Whether the generated HTML loads the vis.js scripts and stylesheet.
            **opts: Extra top-level vis.js options; they override the defaults below.
        """
        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes = {}
        self.edges = []

        for current in node.walk():
            if isinstance(current.expression, exp.Table):
                label = f"FROM {current.expression.this}"
                title = f"<pre>SELECT {current.name} FROM {current.expression.this}</pre>"
                group = 1
            else:
                label = current.expression.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b> tags inside its source. copy=False
                # keeps object identity so the `is` check can match; the transform therefore
                # works on `current.source` itself rather than on a copy.
                source = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is current.expression
                    else n,
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            # Object identity doubles as the vis.js node id, for nodes and edge endpoints alike.
            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for d in current.downstream:
                self.edges.append({"from": node_id, "to": id(d)})

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` to JSON that is safe to embed inside a <script> element."""
        # A "</" inside a string (e.g. SQL containing "</script>") would end the script
        # element early. "<\/" decodes to the same string in JSON/JavaScript but is not
        # seen as a closing tag by the HTML parser.
        return json.dumps(value).replace("</", "<\\/")

    def __str__(self):
        """Return the HTML snippet: container div, optional imports and the vis.js script."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`, for rich display front ends."""
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

LineageHTML( node: Node, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, imports: bool = True, **opts: Any)
204    def __init__(
205        self,
206        node: Node,
207        dialect: DialectType = None,
208        imports: bool = True,
209        **opts: t.Any,
210    ):
211        self.node = node
212        self.imports = imports
213
214        self.options = {
215            "height": "500px",
216            "width": "100%",
217            "layout": {
218                "hierarchical": {
219                    "enabled": True,
220                    "nodeSpacing": 200,
221                    "sortMethod": "directed",
222                },
223            },
224            "interaction": {
225                "dragNodes": False,
226                "selectable": False,
227            },
228            "physics": {
229                "enabled": False,
230            },
231            "edges": {
232                "arrows": "to",
233            },
234            "nodes": {
235                "font": "20px monaco",
236                "shape": "box",
237                "widthConstraint": {
238                    "maximum": 300,
239                },
240            },
241            **opts,
242        }
243
244        self.nodes = {}
245        self.edges = []
246
247        for node in node.walk():
248            if isinstance(node.expression, exp.Table):
249                label = f"FROM {node.expression.this}"
250                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
251                group = 1
252            else:
253                label = node.expression.sql(pretty=True, dialect=dialect)
254                source = node.source.transform(
255                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
256                    if n is node.expression
257                    else n,
258                    copy=False,
259                ).sql(pretty=True, dialect=dialect)
260                title = f"<pre>{source}</pre>"
261                group = 0
262
263            node_id = id(node)
264
265            self.nodes[node_id] = {
266                "id": node_id,
267                "label": label,
268                "title": title,
269                "group": group,
270            }
271
272            for d in node.downstream:
273                self.edges.append({"from": node_id, "to": id(d)})
node
imports
options
nodes
edges