Edit on GitHub

sqlglot.executor.table

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot.dialects.dialect import DialectType
  6from sqlglot.helper import dict_depth
  7from sqlglot.schema import AbstractMappingSchema, normalize_name
  8
  9
 10class Table:
 11    def __init__(self, columns, rows=None, column_range=None):
 12        self.columns = tuple(columns)
 13        self.column_range = column_range
 14        self.reader = RowReader(self.columns, self.column_range)
 15        self.rows = rows or []
 16        if rows:
 17            assert len(rows[0]) == len(self.columns)
 18        self.range_reader = RangeReader(self)
 19
 20    def add_columns(self, *columns: str) -> None:
 21        self.columns += columns
 22        if self.column_range:
 23            self.column_range = range(
 24                self.column_range.start, self.column_range.stop + len(columns)
 25            )
 26        self.reader = RowReader(self.columns, self.column_range)
 27
 28    def append(self, row):
 29        assert len(row) == len(self.columns)
 30        self.rows.append(row)
 31
 32    def pop(self):
 33        self.rows.pop()
 34
 35    @property
 36    def width(self):
 37        return len(self.columns)
 38
 39    def __len__(self):
 40        return len(self.rows)
 41
 42    def __iter__(self):
 43        return TableIter(self)
 44
 45    def __getitem__(self, index):
 46        self.reader.row = self.rows[index]
 47        return self.reader
 48
 49    def __repr__(self):
 50        columns = tuple(
 51            column
 52            for i, column in enumerate(self.columns)
 53            if not self.column_range or i in self.column_range
 54        )
 55        widths = {column: len(column) for column in columns}
 56        lines = [" ".join(column for column in columns)]
 57
 58        for i, row in enumerate(self):
 59            if i > 10:
 60                break
 61
 62            lines.append(
 63                " ".join(
 64                    str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns
 65                )
 66            )
 67        return "\n".join(lines)
 68
 69
 70class TableIter:
 71    def __init__(self, table):
 72        self.table = table
 73        self.index = -1
 74
 75    def __iter__(self):
 76        return self
 77
 78    def __next__(self):
 79        self.index += 1
 80        if self.index < len(self.table):
 81            return self.table[self.index]
 82        raise StopIteration
 83
 84
 85class RangeReader:
 86    def __init__(self, table):
 87        self.table = table
 88        self.range = range(0)
 89
 90    def __len__(self):
 91        return len(self.range)
 92
 93    def __getitem__(self, column):
 94        return (self.table[i][column] for i in self.range)
 95
 96
 97class RowReader:
 98    def __init__(self, columns, column_range=None):
 99        self.columns = {
100            column: i for i, column in enumerate(columns) if not column_range or i in column_range
101        }
102        self.row = None
103
104    def __getitem__(self, column):
105        return self.row[self.columns[column]]
106
107
class Tables(AbstractMappingSchema):
    """Mapping schema used to hold executor tables; adds nothing to its base class."""
110
111
def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables:
    """Normalize ``d`` with ``_ensure_tables`` and wrap the result in ``Tables``.

    Args:
        d: Possibly nested mapping of names to table data, or ``None``.
        dialect: Dialect forwarded to name normalization.
    """
    normalized = _ensure_tables(d, dialect=dialect)
    return Tables(normalized)
114
115
116def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict:
117    if not d:
118        return {}
119
120    depth = dict_depth(d)
121    if depth > 1:
122        return {
123            normalize_name(k, dialect=dialect, is_table=True).name: _ensure_tables(
124                v, dialect=dialect
125            )
126            for k, v in d.items()
127        }
128
129    result = {}
130    for table_name, table in d.items():
131        table_name = normalize_name(table_name, dialect=dialect).name
132
133        if isinstance(table, Table):
134            result[table_name] = table
135        else:
136            table = [
137                {
138                    normalize_name(column_name, dialect=dialect).name: value
139                    for column_name, value in row.items()
140                }
141                for row in table
142            ]
143            column_names = tuple(column_name for column_name in table[0]) if table else ()
144            rows = [tuple(row[name] for name in column_names) for row in table]
145            result[table_name] = Table(columns=column_names, rows=rows)
146
147    return result
class Table:
class Table:
    """In-memory table made of a tuple of column names and a list of rows.

    Rows are read by column name through a single shared ``RowReader``
    (``self.reader``). When ``column_range`` is set, only the column
    positions inside it are exposed by the reader and shown by ``repr``.
    """

    def __init__(self, columns, rows=None, column_range=None):
        """Build a table.

        Args:
            columns: Iterable of column names; stored as a tuple.
            rows: Optional row sequences. A falsy value (``None`` or an
                empty list) is replaced by a fresh empty list.
            column_range: Optional range-like object of visible column
                positions (``add_columns`` reads its ``start``/``stop``).
        """
        self.columns = tuple(columns)
        self.column_range = column_range
        self.reader = RowReader(self.columns, self.column_range)
        self.rows = rows if rows else []
        # Only the first row is checked against the column count.
        assert not rows or len(rows[0]) == len(self.columns)
        self.range_reader = RangeReader(self)

    def add_columns(self, *columns: str) -> None:
        """Append column names and rebuild the row reader.

        If a column range is set, its stop is extended by the number of
        added columns. Existing rows are left untouched.
        """
        self.columns = self.columns + columns
        current = self.column_range
        if current:
            self.column_range = range(current.start, current.stop + len(columns))
        self.reader = RowReader(self.columns, self.column_range)

    def append(self, row):
        """Add ``row`` to the end of the table; its length must equal the column count."""
        assert len(row) == len(self.columns)
        self.rows.append(row)

    def pop(self):
        """Drop the last row. The removed row is not returned."""
        self.rows.pop()

    @property
    def width(self):
        """Number of columns in the table."""
        return len(self.columns)

    def __len__(self):
        """Number of rows in the table."""
        return len(self.rows)

    def __iter__(self):
        """Return a fresh ``TableIter`` over this table."""
        return TableIter(self)

    def __getitem__(self, index):
        """Point the shared reader at ``self.rows[index]`` and return it.

        The same reader object is returned on every call, so a reader
        obtained earlier reflects the most recently indexed row.
        """
        reader = self.reader
        reader.row = self.rows[index]
        return reader

    def __repr__(self):
        """Render a header line followed by at most the first 11 rows.

        Each cell is right-justified and then truncated to the length of
        its column name.
        """
        visible = self.column_range
        columns = tuple(
            name
            for position, name in enumerate(self.columns)
            if not visible or position in visible
        )
        lines = [" ".join(columns)]

        for i, row in enumerate(self):
            if i > 10:
                break

            cells = []
            for name in columns:
                limit = len(name)
                cells.append(str(row[name]).rjust(limit)[:limit])
            lines.append(" ".join(cells))

        return "\n".join(lines)
Table(columns, rows=None, column_range=None)
12    def __init__(self, columns, rows=None, column_range=None):
13        self.columns = tuple(columns)
14        self.column_range = column_range
15        self.reader = RowReader(self.columns, self.column_range)
16        self.rows = rows or []
17        if rows:
18            assert len(rows[0]) == len(self.columns)
19        self.range_reader = RangeReader(self)
columns
column_range
reader
rows
range_reader
def add_columns(self, *columns: str) -> None:
21    def add_columns(self, *columns: str) -> None:
22        self.columns += columns
23        if self.column_range:
24            self.column_range = range(
25                self.column_range.start, self.column_range.stop + len(columns)
26            )
27        self.reader = RowReader(self.columns, self.column_range)
def append(self, row):
29    def append(self, row):
30        assert len(row) == len(self.columns)
31        self.rows.append(row)
def pop(self):
33    def pop(self):
34        self.rows.pop()
width
36    @property
37    def width(self):
38        return len(self.columns)
class TableIter:
class TableIter:
    """Iterator that walks a table by position, from the first row to the last."""

    def __init__(self, table):
        """Start just before the first row of ``table``."""
        self.index = -1
        self.table = table

    def __iter__(self):
        """Return the iterator itself."""
        return self

    def __next__(self):
        """Advance one position and return ``table[index]``.

        Raises:
            StopIteration: once the position reaches ``len(table)``.
        """
        position = self.index + 1
        # The position is stored before the bounds check, so it keeps
        # growing on calls made after exhaustion.
        self.index = position
        if position >= len(self.table):
            raise StopIteration
        return self.table[position]
TableIter(table)
72    def __init__(self, table):
73        self.table = table
74        self.index = -1
table
index
class RangeReader:
class RangeReader:
    """Reads one column across the rows of a table selected by ``self.range``."""

    def __init__(self, table):
        """Bind the reader to ``table`` with an initially empty row range."""
        self.range = range(0)
        self.table = table

    def __len__(self):
        """Return the number of row indices in the current range."""
        return len(self.range)

    def __getitem__(self, column):
        """Return a generator of ``column``'s value for each row index in the range.

        The range is captured when the generator is created; the table
        lookups happen lazily as the generator is consumed.
        """
        return (self.table[row_index][column] for row_index in self.range)
RangeReader(table)
87    def __init__(self, table):
88        self.table = table
89        self.range = range(0)
table
range
class RowReader:
 98class RowReader:
 99    def __init__(self, columns, column_range=None):
100        self.columns = {
101            column: i for i, column in enumerate(columns) if not column_range or i in column_range
102        }
103        self.row = None
104
105    def __getitem__(self, column):
106        return self.row[self.columns[column]]
RowReader(columns, column_range=None)
 99    def __init__(self, columns, column_range=None):
100        self.columns = {
101            column: i for i, column in enumerate(columns) if not column_range or i in column_range
102        }
103        self.row = None
columns
row
class Tables(sqlglot.schema.AbstractMappingSchema):
class Tables(AbstractMappingSchema):
    """Mapping schema used to hold executor tables; adds nothing to its base class."""
def ensure_tables( d: Optional[Dict], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> Tables:
def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables:
    """Normalize ``d`` with ``_ensure_tables`` and wrap the result in ``Tables``.

    Args:
        d: Possibly nested mapping of names to table data, or ``None``.
        dialect: Dialect forwarded to name normalization.
    """
    normalized = _ensure_tables(d, dialect=dialect)
    return Tables(normalized)