sqlglot.executor.table
from __future__ import annotations

import typing as t

from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import dict_depth
from sqlglot.schema import AbstractMappingSchema, normalize_name


class Table:
    """An in-memory table: a tuple of column names plus a list of rows.

    Rows are read by position through one shared ``RowReader`` (``self.reader``).
    Indexing the table repositions that reader, so the object returned by
    ``table[i]`` always reflects whichever row was indexed most recently.
    ``column_range``, when truthy, limits which column indexes the reader and
    ``repr`` expose.
    """

    def __init__(self, columns, rows=None, column_range=None):
        """Build a table from column names and optional rows.

        Args:
            columns: iterable of column names; stored as a tuple.
            rows: optional list of row sequences (tuples when built by
                ``ensure_tables``); a falsy value starts an empty list.
            column_range: optional range of column indexes exposed by the
                row reader; a falsy value exposes every column.
        """
        names = tuple(columns)
        self.columns = names
        self.column_range = column_range
        self.reader = RowReader(names, column_range)
        self.rows = rows if rows else []
        if self.rows:
            # Only the first row is checked here; ``append`` validates later rows.
            assert len(self.rows[0]) == len(names)
        self.range_reader = RangeReader(self)

    def add_columns(self, *columns: str) -> None:
        """Append column names to the table.

        Existing rows are not extended; only the names, the optional
        ``column_range`` and the row reader are updated.
        """
        self.columns += columns
        if self.column_range:
            # Grow the visible window by the number of added columns.
            self.column_range = range(
                self.column_range.start, self.column_range.stop + len(columns)
            )
        self.reader = RowReader(self.columns, self.column_range)

    def append(self, row):
        """Add one row; its width must equal the number of columns."""
        assert len(row) == len(self.columns)
        self.rows.append(row)

    def pop(self):
        """Remove the last row (the removed row is not returned)."""
        self.rows.pop()

    @property
    def width(self):
        """Number of columns, including any outside ``column_range``."""
        return len(self.columns)

    def __len__(self):
        return len(self.rows)

    def __iter__(self):
        return TableIter(self)

    def __getitem__(self, index):
        # Point the shared reader at the requested row and hand it back.
        self.reader.row = self.rows[index]
        return self.reader

    def __repr__(self):
        """Render a right-aligned text grid of the visible columns.

        At most the first 11 rows (indexes 0-10) are shown. Each column is as
        wide as its header or its widest shown value, so values are never
        truncated to the header's width.
        """
        columns = tuple(
            column
            for i, column in enumerate(self.columns)
            if not self.column_range or i in self.column_range
        )

        rendered = []
        for i, row in enumerate(self):
            if i > 10:
                break
            rendered.append([str(row[column]) for column in columns])

        widths = [len(column) for column in columns]
        for values in rendered:
            widths = [max(width, len(value)) for width, value in zip(widths, values)]

        lines = [" ".join(column.rjust(width) for column, width in zip(columns, widths))]
        lines.extend(
            " ".join(value.rjust(width) for value, width in zip(values, widths))
            for values in rendered
        )
        return "\n".join(lines)


class TableIter:
    """Iterator over a ``Table`` that yields its shared ``RowReader`` per row."""

    def __init__(self, table):
        self.table = table
        self.index = -1

    def __iter__(self):
        return self

    def __next__(self):
        self.index += 1
        if self.index < len(self.table):
            return self.table[self.index]
        raise StopIteration


class RangeReader:
    """Column-wise view over a range of a table's row indexes.

    ``range`` starts empty. Presumably the code using this reader reassigns it
    to the row indexes of interest before reading; that is not shown in this
    module, so confirm at the call sites.
    """

    def __init__(self, table):
        self.table = table
        self.range = range(0)

    def __len__(self):
        return len(self.range)

    def __getitem__(self, column):
        # Lazy: each step re-indexes the table (repositioning its shared
        # RowReader) and then reads ``column`` from that row.
        return (self.table[i][column] for i in self.range)


class RowReader:
    """Name-based accessor over a single positional row.

    ``columns`` maps each visible column name to its index in the row.
    ``row`` is assigned externally (``Table.__getitem__`` does so) before
    lookups.
    """

    def __init__(self, columns, column_range=None):
        # A falsy column_range (None or an empty range) exposes every column.
        self.columns = {}
        for position, name in enumerate(columns):
            if column_range and position not in column_range:
                continue
            self.columns[name] = position
        self.row = None

    def __getitem__(self, column):
        index = self.columns[column]
        return self.row[index]


class Tables(AbstractMappingSchema):
    """Mapping schema whose leaf values are ``Table`` objects (see ``ensure_tables``)."""


def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables:
    """Wrap ``d`` in a ``Tables`` schema after normalizing its names.

    Args:
        d: mapping of (optionally nested) table names to either ``Table``
            objects or iterables of row dicts; ``None``/empty yields an
            empty mapping.
        dialect: dialect used to normalize table and column names.

    Returns:
        A ``Tables`` instance wrapping the normalized mapping.
    """
    return Tables(_ensure_tables(d, dialect=dialect))


def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict:
    """Recursively normalize names in ``d`` and convert row lists to ``Table``s.

    Returns an empty dict for a falsy ``d``. Raises ``KeyError`` if a row
    lacks a column that appears in the first row of its table.
    """
    if not d:
        return {}

    depth = dict_depth(d)
    if depth > 1:
        # Outer keys are table-path components: normalize them as table
        # names and recurse into each nested mapping.
        return {
            normalize_name(k, dialect=dialect, is_table=True).name: _ensure_tables(
                v, dialect=dialect
            )
            for k, v in d.items()
        }

    result = {}
    for table_name, table in d.items():
        # NOTE(review): unlike the outer path keys above, leaf table names are
        # normalized without is_table=True. Confirm this is intended for
        # dialects whose identifier normalization depends on that flag.
        table_name = normalize_name(table_name, dialect=dialect).name

        if isinstance(table, Table):
            result[table_name] = table
            continue

        # ``table`` is an iterable of row mappings (column name -> value).
        normalized_rows = [
            {
                normalize_name(column_name, dialect=dialect).name: value
                for column_name, value in row.items()
            }
            for row in table
        ]
        # Column order comes from the first row; every row must have those keys.
        column_names = tuple(normalized_rows[0]) if normalized_rows else ()
        rows = [tuple(row[name] for name in column_names) for row in normalized_rows]
        result[table_name] = Table(columns=column_names, rows=rows)

    return result
class
Table:
class Table:
    """An in-memory table: a tuple of column names plus a list of rows.

    Rows are read by position through one shared ``RowReader`` (``self.reader``).
    Indexing the table repositions that reader, so the object returned by
    ``table[i]`` always reflects whichever row was indexed most recently.
    ``column_range``, when truthy, limits which column indexes the reader and
    ``repr`` expose.
    """

    def __init__(self, columns, rows=None, column_range=None):
        """Build a table from column names and optional rows.

        Args:
            columns: iterable of column names; stored as a tuple.
            rows: optional list of row sequences (tuples when built by
                ``ensure_tables``); a falsy value starts an empty list.
            column_range: optional range of column indexes exposed by the
                row reader; a falsy value exposes every column.
        """
        names = tuple(columns)
        self.columns = names
        self.column_range = column_range
        self.reader = RowReader(names, column_range)
        self.rows = rows if rows else []
        if self.rows:
            # Only the first row is checked here; ``append`` validates later rows.
            assert len(self.rows[0]) == len(names)
        self.range_reader = RangeReader(self)

    def add_columns(self, *columns: str) -> None:
        """Append column names to the table.

        Existing rows are not extended; only the names, the optional
        ``column_range`` and the row reader are updated.
        """
        self.columns += columns
        if self.column_range:
            # Grow the visible window by the number of added columns.
            self.column_range = range(
                self.column_range.start, self.column_range.stop + len(columns)
            )
        self.reader = RowReader(self.columns, self.column_range)

    def append(self, row):
        """Add one row; its width must equal the number of columns."""
        assert len(row) == len(self.columns)
        self.rows.append(row)

    def pop(self):
        """Remove the last row (the removed row is not returned)."""
        self.rows.pop()

    @property
    def width(self):
        """Number of columns, including any outside ``column_range``."""
        return len(self.columns)

    def __len__(self):
        return len(self.rows)

    def __iter__(self):
        return TableIter(self)

    def __getitem__(self, index):
        # Point the shared reader at the requested row and hand it back.
        self.reader.row = self.rows[index]
        return self.reader

    def __repr__(self):
        """Render a right-aligned text grid of the visible columns.

        At most the first 11 rows (indexes 0-10) are shown. Each column is as
        wide as its header or its widest shown value, so values are never
        truncated to the header's width.
        """
        columns = tuple(
            column
            for i, column in enumerate(self.columns)
            if not self.column_range or i in self.column_range
        )

        rendered = []
        for i, row in enumerate(self):
            if i > 10:
                break
            rendered.append([str(row[column]) for column in columns])

        widths = [len(column) for column in columns]
        for values in rendered:
            widths = [max(width, len(value)) for width, value in zip(widths, values)]

        lines = [" ".join(column.rjust(width) for column, width in zip(columns, widths))]
        lines.extend(
            " ".join(value.rjust(width) for value, width in zip(values, widths))
            for values in rendered
        )
        return "\n".join(lines)
Table(columns, rows=None, column_range=None)
def __init__(self, columns, rows=None, column_range=None):
    """Build a table from column names and optional rows.

    Args:
        columns: iterable of column names; stored as a tuple.
        rows: optional list of row sequences (tuples when built by
            ``ensure_tables``); a falsy value starts an empty list.
        column_range: optional range of column indexes exposed by the
            row reader; a falsy value exposes every column.
    """
    names = tuple(columns)
    self.columns = names
    self.column_range = column_range
    self.reader = RowReader(names, column_range)
    self.rows = rows if rows else []
    if self.rows:
        # Only the first row is checked here; ``append`` validates later rows.
        assert len(self.rows[0]) == len(names)
    self.range_reader = RangeReader(self)
class
TableIter:
class
RangeReader:
class
RowReader:
class RowReader:
    """Name-based accessor over a single positional row.

    ``columns`` maps each visible column name to its index in the row.
    ``row`` is assigned externally (``Table.__getitem__`` does so) before
    lookups.
    """

    def __init__(self, columns, column_range=None):
        # A falsy column_range (None or an empty range) exposes every column.
        self.columns = {}
        for position, name in enumerate(columns):
            if column_range and position not in column_range:
                continue
            self.columns[name] = position
        self.row = None

    def __getitem__(self, column):
        index = self.columns[column]
        return self.row[index]
def
ensure_tables(d: Optional[Dict], dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> Tables: