sqlglot.schema
1from __future__ import annotations 2 3import abc 4import typing as t 5 6from sqlglot import expressions as exp 7from sqlglot.dialects.dialect import Dialect 8from sqlglot.errors import SchemaError 9from sqlglot.helper import dict_depth, first 10from sqlglot.trie import TrieResult, in_trie, new_trie 11 12if t.TYPE_CHECKING: 13 from sqlglot.dialects.dialect import DialectType 14 15 ColumnMapping = t.Union[t.Dict, str, t.List] 16 17 18class Schema(abc.ABC): 19 """Abstract base class for database schemas""" 20 21 dialect: DialectType 22 23 @abc.abstractmethod 24 def add_table( 25 self, 26 table: exp.Table | str, 27 column_mapping: t.Optional[ColumnMapping] = None, 28 dialect: DialectType = None, 29 normalize: t.Optional[bool] = None, 30 match_depth: bool = True, 31 ) -> None: 32 """ 33 Register or update a table. Some implementing classes may require column information to also be provided. 34 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 35 36 Args: 37 table: the `Table` expression instance or string representing the table. 38 column_mapping: a column mapping that describes the structure of the table. 39 dialect: the SQL dialect that will be used to parse `table` if it's a string. 40 normalize: whether to normalize identifiers according to the dialect of interest. 41 match_depth: whether to enforce that the table must match the schema's depth or not. 42 """ 43 44 @abc.abstractmethod 45 def column_names( 46 self, 47 table: exp.Table | str, 48 only_visible: bool = False, 49 dialect: DialectType = None, 50 normalize: t.Optional[bool] = None, 51 ) -> t.Sequence[str]: 52 """ 53 Get the column names for a table. 54 55 Args: 56 table: the `Table` expression instance. 57 only_visible: whether to include invisible columns. 58 dialect: the SQL dialect that will be used to parse `table` if it's a string. 59 normalize: whether to normalize identifiers according to the dialect of interest. 
60 61 Returns: 62 The sequence of column names. 63 """ 64 65 @abc.abstractmethod 66 def get_column_type( 67 self, 68 table: exp.Table | str, 69 column: exp.Column | str, 70 dialect: DialectType = None, 71 normalize: t.Optional[bool] = None, 72 ) -> exp.DataType: 73 """ 74 Get the `sqlglot.exp.DataType` type of a column in the schema. 75 76 Args: 77 table: the source table. 78 column: the target column. 79 dialect: the SQL dialect that will be used to parse `table` if it's a string. 80 normalize: whether to normalize identifiers according to the dialect of interest. 81 82 Returns: 83 The resulting column type. 84 """ 85 86 def has_column( 87 self, 88 table: exp.Table | str, 89 column: exp.Column | str, 90 dialect: DialectType = None, 91 normalize: t.Optional[bool] = None, 92 ) -> bool: 93 """ 94 Returns whether `column` appears in `table`'s schema. 95 96 Args: 97 table: the source table. 98 column: the target column. 99 dialect: the SQL dialect that will be used to parse `table` if it's a string. 100 normalize: whether to normalize identifiers according to the dialect of interest. 101 102 Returns: 103 True if the column appears in the schema, False otherwise. 104 """ 105 name = column if isinstance(column, str) else column.name 106 return name in self.column_names(table, dialect=dialect, normalize=normalize) 107 108 @property 109 @abc.abstractmethod 110 def supported_table_args(self) -> t.Tuple[str, ...]: 111 """ 112 Table arguments this schema support, e.g. `("this", "db", "catalog")` 113 """ 114 115 @property 116 def empty(self) -> bool: 117 """Returns whether the schema is empty.""" 118 return True 119 120 121class AbstractMappingSchema: 122 def __init__( 123 self, 124 mapping: t.Optional[t.Dict] = None, 125 ) -> None: 126 self.mapping = mapping or {} 127 self.mapping_trie = new_trie( 128 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 129 ) 130 self._supported_table_args: t.Tuple[str, ...] 
= tuple() 131 132 @property 133 def empty(self) -> bool: 134 return not self.mapping 135 136 def depth(self) -> int: 137 return dict_depth(self.mapping) 138 139 @property 140 def supported_table_args(self) -> t.Tuple[str, ...]: 141 if not self._supported_table_args and self.mapping: 142 depth = self.depth() 143 144 if not depth: # None 145 self._supported_table_args = tuple() 146 elif 1 <= depth <= 3: 147 self._supported_table_args = exp.TABLE_PARTS[:depth] 148 else: 149 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 150 151 return self._supported_table_args 152 153 def table_parts(self, table: exp.Table) -> t.List[str]: 154 return [part.name for part in reversed(table.parts)] 155 156 def find( 157 self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False 158 ) -> t.Optional[t.Any]: 159 """ 160 Returns the schema of a given table. 161 162 Args: 163 table: the target table. 164 raise_on_missing: whether to raise in case the schema is not found. 165 ensure_data_types: whether to convert `str` types to their `DataType` equivalents. 166 167 Returns: 168 The schema of the target table. 
169 """ 170 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 171 value, trie = in_trie(self.mapping_trie, parts) 172 173 if value == TrieResult.FAILED: 174 return None 175 176 if value == TrieResult.PREFIX: 177 possibilities = flatten_schema(trie) 178 179 if len(possibilities) == 1: 180 parts.extend(possibilities[0]) 181 else: 182 message = ", ".join(".".join(parts) for parts in possibilities) 183 if raise_on_missing: 184 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 185 return None 186 187 return self.nested_get(parts, raise_on_missing=raise_on_missing) 188 189 def nested_get( 190 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 191 ) -> t.Optional[t.Any]: 192 return nested_get( 193 d or self.mapping, 194 *zip(self.supported_table_args, reversed(parts)), 195 raise_on_missing=raise_on_missing, 196 ) 197 198 199class MappingSchema(AbstractMappingSchema, Schema): 200 """ 201 Schema based on a nested mapping. 202 203 Args: 204 schema: Mapping in one of the following forms: 205 1. {table: {col: type}} 206 2. {db: {table: {col: type}}} 207 3. {catalog: {db: {table: {col: type}}}} 208 4. None - Tables will be added later 209 visible: Optional mapping of which columns in the schema are visible. If not provided, all columns 210 are assumed to be visible. The nesting should mirror that of the schema: 211 1. {table: set(*cols)}} 212 2. {db: {table: set(*cols)}}} 213 3. {catalog: {db: {table: set(*cols)}}}} 214 dialect: The dialect to be used for custom type mappings & parsing string arguments. 215 normalize: Whether to normalize identifier names according to the given dialect or not. 
216 """ 217 218 def __init__( 219 self, 220 schema: t.Optional[t.Dict] = None, 221 visible: t.Optional[t.Dict] = None, 222 dialect: DialectType = None, 223 normalize: bool = True, 224 ) -> None: 225 self.dialect = dialect 226 self.visible = {} if visible is None else visible 227 self.normalize = normalize 228 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 229 self._depth = 0 230 schema = {} if schema is None else schema 231 232 super().__init__(self._normalize(schema) if self.normalize else schema) 233 234 @classmethod 235 def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: 236 return MappingSchema( 237 schema=mapping_schema.mapping, 238 visible=mapping_schema.visible, 239 dialect=mapping_schema.dialect, 240 normalize=mapping_schema.normalize, 241 ) 242 243 def find( 244 self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False 245 ) -> t.Optional[t.Any]: 246 schema = super().find( 247 table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types 248 ) 249 if ensure_data_types and isinstance(schema, dict): 250 schema = { 251 col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype 252 for col, dtype in schema.items() 253 } 254 255 return schema 256 257 def copy(self, **kwargs) -> MappingSchema: 258 return MappingSchema( 259 **{ # type: ignore 260 "schema": self.mapping.copy(), 261 "visible": self.visible.copy(), 262 "dialect": self.dialect, 263 "normalize": self.normalize, 264 **kwargs, 265 } 266 ) 267 268 def add_table( 269 self, 270 table: exp.Table | str, 271 column_mapping: t.Optional[ColumnMapping] = None, 272 dialect: DialectType = None, 273 normalize: t.Optional[bool] = None, 274 match_depth: bool = True, 275 ) -> None: 276 """ 277 Register or update a table. Updates are only performed if a new column mapping is provided. 278 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 
279 280 Args: 281 table: the `Table` expression instance or string representing the table. 282 column_mapping: a column mapping that describes the structure of the table. 283 dialect: the SQL dialect that will be used to parse `table` if it's a string. 284 normalize: whether to normalize identifiers according to the dialect of interest. 285 match_depth: whether to enforce that the table must match the schema's depth or not. 286 """ 287 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 288 289 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 290 raise SchemaError( 291 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 292 f"schema's nesting level: {self.depth()}." 293 ) 294 295 normalized_column_mapping = { 296 self._normalize_name(key, dialect=dialect, normalize=normalize): value 297 for key, value in ensure_column_mapping(column_mapping).items() 298 } 299 300 schema = self.find(normalized_table, raise_on_missing=False) 301 if schema and not normalized_column_mapping: 302 return 303 304 parts = self.table_parts(normalized_table) 305 306 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 307 new_trie([parts], self.mapping_trie) 308 309 def column_names( 310 self, 311 table: exp.Table | str, 312 only_visible: bool = False, 313 dialect: DialectType = None, 314 normalize: t.Optional[bool] = None, 315 ) -> t.List[str]: 316 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 317 318 schema = self.find(normalized_table) 319 if schema is None: 320 return [] 321 322 if not only_visible or not self.visible: 323 return list(schema) 324 325 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 326 return [col for col in schema if col in visible] 327 328 def get_column_type( 329 self, 330 table: exp.Table | str, 331 column: exp.Column | str, 332 dialect: DialectType = None, 333 normalize: t.Optional[bool] 
= None, 334 ) -> exp.DataType: 335 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 336 337 normalized_column_name = self._normalize_name( 338 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 339 ) 340 341 table_schema = self.find(normalized_table, raise_on_missing=False) 342 if table_schema: 343 column_type = table_schema.get(normalized_column_name) 344 345 if isinstance(column_type, exp.DataType): 346 return column_type 347 elif isinstance(column_type, str): 348 return self._to_data_type(column_type, dialect=dialect) 349 350 return exp.DataType.build("unknown") 351 352 def has_column( 353 self, 354 table: exp.Table | str, 355 column: exp.Column | str, 356 dialect: DialectType = None, 357 normalize: t.Optional[bool] = None, 358 ) -> bool: 359 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 360 361 normalized_column_name = self._normalize_name( 362 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 363 ) 364 365 table_schema = self.find(normalized_table, raise_on_missing=False) 366 return normalized_column_name in table_schema if table_schema else False 367 368 def _normalize(self, schema: t.Dict) -> t.Dict: 369 """ 370 Normalizes all identifiers in the schema. 371 372 Args: 373 schema: the schema to normalize. 374 375 Returns: 376 The normalized schema mapping. 377 """ 378 normalized_mapping: t.Dict = {} 379 flattened_schema = flatten_schema(schema) 380 error_msg = "Table {} must match the schema's nesting level: {}." 
381 382 for keys in flattened_schema: 383 columns = nested_get(schema, *zip(keys, keys)) 384 385 if not isinstance(columns, dict): 386 raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0]))) 387 if not columns: 388 raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column") 389 if isinstance(first(columns.values()), dict): 390 raise SchemaError( 391 error_msg.format( 392 ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0]) 393 ), 394 ) 395 396 normalized_keys = [self._normalize_name(key, is_table=True) for key in keys] 397 for column_name, column_type in columns.items(): 398 nested_set( 399 normalized_mapping, 400 normalized_keys + [self._normalize_name(column_name)], 401 column_type, 402 ) 403 404 return normalized_mapping 405 406 def _normalize_table( 407 self, 408 table: exp.Table | str, 409 dialect: DialectType = None, 410 normalize: t.Optional[bool] = None, 411 ) -> exp.Table: 412 dialect = dialect or self.dialect 413 normalize = self.normalize if normalize is None else normalize 414 415 normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize) 416 417 if normalize: 418 for part in normalized_table.parts: 419 if isinstance(part, exp.Identifier): 420 part.replace( 421 normalize_name(part, dialect=dialect, is_table=True, normalize=normalize) 422 ) 423 424 return normalized_table 425 426 def _normalize_name( 427 self, 428 name: str | exp.Identifier, 429 dialect: DialectType = None, 430 is_table: bool = False, 431 normalize: t.Optional[bool] = None, 432 ) -> str: 433 return normalize_name( 434 name, 435 dialect=dialect or self.dialect, 436 is_table=is_table, 437 normalize=self.normalize if normalize is None else normalize, 438 ).name 439 440 def depth(self) -> int: 441 if not self.empty and not self._depth: 442 # The columns themselves are a mapping, but we don't want to include those 443 self._depth = super().depth() - 1 444 return self._depth 445 446 def 
_to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: 447 """ 448 Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. 449 450 Args: 451 schema_type: the type we want to convert. 452 dialect: the SQL dialect that will be used to parse `schema_type`, if needed. 453 454 Returns: 455 The resulting expression type. 456 """ 457 if schema_type not in self._type_mapping_cache: 458 dialect = dialect or self.dialect 459 udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES 460 461 try: 462 expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt) 463 self._type_mapping_cache[schema_type] = expression 464 except AttributeError: 465 in_dialect = f" in dialect {dialect}" if dialect else "" 466 raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") 467 468 return self._type_mapping_cache[schema_type] 469 470 471def normalize_name( 472 identifier: str | exp.Identifier, 473 dialect: DialectType = None, 474 is_table: bool = False, 475 normalize: t.Optional[bool] = True, 476) -> exp.Identifier: 477 if isinstance(identifier, str): 478 identifier = exp.parse_identifier(identifier, dialect=dialect) 479 480 if not normalize: 481 return identifier 482 483 # this is used for normalize_identifier, bigquery has special rules pertaining tables 484 identifier.meta["is_table"] = is_table 485 return Dialect.get_or_raise(dialect).normalize_identifier(identifier) 486 487 488def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: 489 if isinstance(schema, Schema): 490 return schema 491 492 return MappingSchema(schema, **kwargs) 493 494 495def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 496 if mapping is None: 497 return {} 498 elif isinstance(mapping, dict): 499 return mapping 500 elif isinstance(mapping, str): 501 col_name_type_strs = [x.strip() for x in mapping.split(",")] 502 return { 503 name_type_str.split(":")[0].strip(): 
name_type_str.split(":")[1].strip() 504 for name_type_str in col_name_type_strs 505 } 506 elif isinstance(mapping, list): 507 return {x.strip(): None for x in mapping} 508 509 raise ValueError(f"Invalid mapping provided: {type(mapping)}") 510 511 512def flatten_schema( 513 schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None 514) -> t.List[t.List[str]]: 515 tables = [] 516 keys = keys or [] 517 depth = dict_depth(schema) - 1 if depth is None else depth 518 519 for k, v in schema.items(): 520 if depth == 1 or not isinstance(v, dict): 521 tables.append(keys + [k]) 522 elif depth >= 2: 523 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 524 525 return tables 526 527 528def nested_get( 529 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 530) -> t.Optional[t.Any]: 531 """ 532 Get a value for a nested dictionary. 533 534 Args: 535 d: the dictionary to search. 536 *path: tuples of (name, key), where: 537 `key` is the key in the dictionary to get. 538 `name` is a string to use in the error if `key` isn't found. 539 540 Returns: 541 The value or None if it doesn't exist. 542 """ 543 for name, key in path: 544 d = d.get(key) # type: ignore 545 if d is None: 546 if raise_on_missing: 547 name = "table" if name == "this" else name 548 raise ValueError(f"Unknown {name}: {key}") 549 return None 550 551 return d 552 553 554def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 555 """ 556 In-place set a value for a nested dictionary 557 558 Example: 559 >>> nested_set({}, ["top_key", "second_key"], "value") 560 {'top_key': {'second_key': 'value'}} 561 562 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 563 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 564 565 Args: 566 d: dictionary to update. 567 keys: the keys that makeup the path to `value`. 568 value: the value to set in the dictionary for the given key path. 
569 570 Returns: 571 The (possibly) updated dictionary. 572 """ 573 if not keys: 574 return d 575 576 if len(keys) == 1: 577 d[keys[0]] = value 578 return d 579 580 subd = d 581 for key in keys[:-1]: 582 if key not in subd: 583 subd = subd.setdefault(key, {}) 584 else: 585 subd = subd[key] 586 587 subd[keys[-1]] = value 588 return d
19class Schema(abc.ABC): 20 """Abstract base class for database schemas""" 21 22 dialect: DialectType 23 24 @abc.abstractmethod 25 def add_table( 26 self, 27 table: exp.Table | str, 28 column_mapping: t.Optional[ColumnMapping] = None, 29 dialect: DialectType = None, 30 normalize: t.Optional[bool] = None, 31 match_depth: bool = True, 32 ) -> None: 33 """ 34 Register or update a table. Some implementing classes may require column information to also be provided. 35 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 36 37 Args: 38 table: the `Table` expression instance or string representing the table. 39 column_mapping: a column mapping that describes the structure of the table. 40 dialect: the SQL dialect that will be used to parse `table` if it's a string. 41 normalize: whether to normalize identifiers according to the dialect of interest. 42 match_depth: whether to enforce that the table must match the schema's depth or not. 43 """ 44 45 @abc.abstractmethod 46 def column_names( 47 self, 48 table: exp.Table | str, 49 only_visible: bool = False, 50 dialect: DialectType = None, 51 normalize: t.Optional[bool] = None, 52 ) -> t.Sequence[str]: 53 """ 54 Get the column names for a table. 55 56 Args: 57 table: the `Table` expression instance. 58 only_visible: whether to include invisible columns. 59 dialect: the SQL dialect that will be used to parse `table` if it's a string. 60 normalize: whether to normalize identifiers according to the dialect of interest. 61 62 Returns: 63 The sequence of column names. 64 """ 65 66 @abc.abstractmethod 67 def get_column_type( 68 self, 69 table: exp.Table | str, 70 column: exp.Column | str, 71 dialect: DialectType = None, 72 normalize: t.Optional[bool] = None, 73 ) -> exp.DataType: 74 """ 75 Get the `sqlglot.exp.DataType` type of a column in the schema. 76 77 Args: 78 table: the source table. 79 column: the target column. 
80 dialect: the SQL dialect that will be used to parse `table` if it's a string. 81 normalize: whether to normalize identifiers according to the dialect of interest. 82 83 Returns: 84 The resulting column type. 85 """ 86 87 def has_column( 88 self, 89 table: exp.Table | str, 90 column: exp.Column | str, 91 dialect: DialectType = None, 92 normalize: t.Optional[bool] = None, 93 ) -> bool: 94 """ 95 Returns whether `column` appears in `table`'s schema. 96 97 Args: 98 table: the source table. 99 column: the target column. 100 dialect: the SQL dialect that will be used to parse `table` if it's a string. 101 normalize: whether to normalize identifiers according to the dialect of interest. 102 103 Returns: 104 True if the column appears in the schema, False otherwise. 105 """ 106 name = column if isinstance(column, str) else column.name 107 return name in self.column_names(table, dialect=dialect, normalize=normalize) 108 109 @property 110 @abc.abstractmethod 111 def supported_table_args(self) -> t.Tuple[str, ...]: 112 """ 113 Table arguments this schema support, e.g. `("this", "db", "catalog")` 114 """ 115 116 @property 117 def empty(self) -> bool: 118 """Returns whether the schema is empty.""" 119 return True
Abstract base class for database schemas
24 @abc.abstractmethod 25 def add_table( 26 self, 27 table: exp.Table | str, 28 column_mapping: t.Optional[ColumnMapping] = None, 29 dialect: DialectType = None, 30 normalize: t.Optional[bool] = None, 31 match_depth: bool = True, 32 ) -> None: 33 """ 34 Register or update a table. Some implementing classes may require column information to also be provided. 35 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 36 37 Args: 38 table: the `Table` expression instance or string representing the table. 39 column_mapping: a column mapping that describes the structure of the table. 40 dialect: the SQL dialect that will be used to parse `table` if it's a string. 41 normalize: whether to normalize identifiers according to the dialect of interest. 42 match_depth: whether to enforce that the table must match the schema's depth or not. 43 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
45 @abc.abstractmethod 46 def column_names( 47 self, 48 table: exp.Table | str, 49 only_visible: bool = False, 50 dialect: DialectType = None, 51 normalize: t.Optional[bool] = None, 52 ) -> t.Sequence[str]: 53 """ 54 Get the column names for a table. 55 56 Args: 57 table: the `Table` expression instance. 58 only_visible: whether to include invisible columns. 59 dialect: the SQL dialect that will be used to parse `table` if it's a string. 60 normalize: whether to normalize identifiers according to the dialect of interest. 61 62 Returns: 63 The sequence of column names. 64 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the result to only the visible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
66 @abc.abstractmethod 67 def get_column_type( 68 self, 69 table: exp.Table | str, 70 column: exp.Column | str, 71 dialect: DialectType = None, 72 normalize: t.Optional[bool] = None, 73 ) -> exp.DataType: 74 """ 75 Get the `sqlglot.exp.DataType` type of a column in the schema. 76 77 Args: 78 table: the source table. 79 column: the target column. 80 dialect: the SQL dialect that will be used to parse `table` if it's a string. 81 normalize: whether to normalize identifiers according to the dialect of interest. 82 83 Returns: 84 The resulting column type. 85 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
87 def has_column( 88 self, 89 table: exp.Table | str, 90 column: exp.Column | str, 91 dialect: DialectType = None, 92 normalize: t.Optional[bool] = None, 93 ) -> bool: 94 """ 95 Returns whether `column` appears in `table`'s schema. 96 97 Args: 98 table: the source table. 99 column: the target column. 100 dialect: the SQL dialect that will be used to parse `table` if it's a string. 101 normalize: whether to normalize identifiers according to the dialect of interest. 102 103 Returns: 104 True if the column appears in the schema, False otherwise. 105 """ 106 name = column if isinstance(column, str) else column.name 107 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
122class AbstractMappingSchema: 123 def __init__( 124 self, 125 mapping: t.Optional[t.Dict] = None, 126 ) -> None: 127 self.mapping = mapping or {} 128 self.mapping_trie = new_trie( 129 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 130 ) 131 self._supported_table_args: t.Tuple[str, ...] = tuple() 132 133 @property 134 def empty(self) -> bool: 135 return not self.mapping 136 137 def depth(self) -> int: 138 return dict_depth(self.mapping) 139 140 @property 141 def supported_table_args(self) -> t.Tuple[str, ...]: 142 if not self._supported_table_args and self.mapping: 143 depth = self.depth() 144 145 if not depth: # None 146 self._supported_table_args = tuple() 147 elif 1 <= depth <= 3: 148 self._supported_table_args = exp.TABLE_PARTS[:depth] 149 else: 150 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 151 152 return self._supported_table_args 153 154 def table_parts(self, table: exp.Table) -> t.List[str]: 155 return [part.name for part in reversed(table.parts)] 156 157 def find( 158 self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False 159 ) -> t.Optional[t.Any]: 160 """ 161 Returns the schema of a given table. 162 163 Args: 164 table: the target table. 165 raise_on_missing: whether to raise in case the schema is not found. 166 ensure_data_types: whether to convert `str` types to their `DataType` equivalents. 167 168 Returns: 169 The schema of the target table. 
170 """ 171 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 172 value, trie = in_trie(self.mapping_trie, parts) 173 174 if value == TrieResult.FAILED: 175 return None 176 177 if value == TrieResult.PREFIX: 178 possibilities = flatten_schema(trie) 179 180 if len(possibilities) == 1: 181 parts.extend(possibilities[0]) 182 else: 183 message = ", ".join(".".join(parts) for parts in possibilities) 184 if raise_on_missing: 185 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 186 return None 187 188 return self.nested_get(parts, raise_on_missing=raise_on_missing) 189 190 def nested_get( 191 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 192 ) -> t.Optional[t.Any]: 193 return nested_get( 194 d or self.mapping, 195 *zip(self.supported_table_args, reversed(parts)), 196 raise_on_missing=raise_on_missing, 197 )
140 @property 141 def supported_table_args(self) -> t.Tuple[str, ...]: 142 if not self._supported_table_args and self.mapping: 143 depth = self.depth() 144 145 if not depth: # None 146 self._supported_table_args = tuple() 147 elif 1 <= depth <= 3: 148 self._supported_table_args = exp.TABLE_PARTS[:depth] 149 else: 150 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 151 152 return self._supported_table_args
157 def find( 158 self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False 159 ) -> t.Optional[t.Any]: 160 """ 161 Returns the schema of a given table. 162 163 Args: 164 table: the target table. 165 raise_on_missing: whether to raise in case the schema is not found. 166 ensure_data_types: whether to convert `str` types to their `DataType` equivalents. 167 168 Returns: 169 The schema of the target table. 170 """ 171 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 172 value, trie = in_trie(self.mapping_trie, parts) 173 174 if value == TrieResult.FAILED: 175 return None 176 177 if value == TrieResult.PREFIX: 178 possibilities = flatten_schema(trie) 179 180 if len(possibilities) == 1: 181 parts.extend(possibilities[0]) 182 else: 183 message = ", ".join(".".join(parts) for parts in possibilities) 184 if raise_on_missing: 185 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 186 return None 187 188 return self.nested_get(parts, raise_on_missing=raise_on_missing)
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache of string type -> parsed exp.DataType, so repeated column types
        # are only parsed once (see _to_data_type).
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily-computed nesting depth; 0 means "not yet computed" (see depth()).
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new MappingSchema from an existing one, reusing its mapping and settings."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the copied constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # An existing table is only updated when a new (non-empty) column mapping is supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        # Return value is unused: this relies on new_trie extending mapping_trie in place.
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names registered for `table` (optionally filtered to visible ones)."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Resolve the `DataType` of `column` in `table`; the "unknown" type if not resolvable."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a raw string: parse (and cache) it into a DataType.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return True if `column` appears in `table`'s schema, False otherwise."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                # A dict where a column type should be means the table path is nested
                # deeper than the schema's uniform depth.
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Only copy when we are going to mutate the parsed expression below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for part in normalized_table.parts:
                if isinstance(part, exp.Identifier):
                    part.replace(
                        normalize_name(part, dialect=dialect, is_table=True, normalize=normalize)
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its string name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Return the schema's nesting level (e.g. 1 for {table: {col: type}})."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                # NOTE(review): AttributeError is what build() surfaces for unparsable
                # types here — confirm this is still the failure mode upstream.
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}}
- {db: {table: set(*cols)}}}
- {catalog: {db: {table: set(*cols)}}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
219 def __init__( 220 self, 221 schema: t.Optional[t.Dict] = None, 222 visible: t.Optional[t.Dict] = None, 223 dialect: DialectType = None, 224 normalize: bool = True, 225 ) -> None: 226 self.dialect = dialect 227 self.visible = {} if visible is None else visible 228 self.normalize = normalize 229 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 230 self._depth = 0 231 schema = {} if schema is None else schema 232 233 super().__init__(self._normalize(schema) if self.normalize else schema)
244 def find( 245 self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False 246 ) -> t.Optional[t.Any]: 247 schema = super().find( 248 table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types 249 ) 250 if ensure_data_types and isinstance(schema, dict): 251 schema = { 252 col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype 253 for col, dtype in schema.items() 254 } 255 256 return schema
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
269 def add_table( 270 self, 271 table: exp.Table | str, 272 column_mapping: t.Optional[ColumnMapping] = None, 273 dialect: DialectType = None, 274 normalize: t.Optional[bool] = None, 275 match_depth: bool = True, 276 ) -> None: 277 """ 278 Register or update a table. Updates are only performed if a new column mapping is provided. 279 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 280 281 Args: 282 table: the `Table` expression instance or string representing the table. 283 column_mapping: a column mapping that describes the structure of the table. 284 dialect: the SQL dialect that will be used to parse `table` if it's a string. 285 normalize: whether to normalize identifiers according to the dialect of interest. 286 match_depth: whether to enforce that the table must match the schema's depth or not. 287 """ 288 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 289 290 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 291 raise SchemaError( 292 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 293 f"schema's nesting level: {self.depth()}." 294 ) 295 296 normalized_column_mapping = { 297 self._normalize_name(key, dialect=dialect, normalize=normalize): value 298 for key, value in ensure_column_mapping(column_mapping).items() 299 } 300 301 schema = self.find(normalized_table, raise_on_missing=False) 302 if schema and not normalized_column_mapping: 303 return 304 305 parts = self.table_parts(normalized_table) 306 307 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 308 new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
310 def column_names( 311 self, 312 table: exp.Table | str, 313 only_visible: bool = False, 314 dialect: DialectType = None, 315 normalize: t.Optional[bool] = None, 316 ) -> t.List[str]: 317 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 318 319 schema = self.find(normalized_table) 320 if schema is None: 321 return [] 322 323 if not only_visible or not self.visible: 324 return list(schema) 325 326 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 327 return [col for col in schema if col in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
329 def get_column_type( 330 self, 331 table: exp.Table | str, 332 column: exp.Column | str, 333 dialect: DialectType = None, 334 normalize: t.Optional[bool] = None, 335 ) -> exp.DataType: 336 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 337 338 normalized_column_name = self._normalize_name( 339 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 340 ) 341 342 table_schema = self.find(normalized_table, raise_on_missing=False) 343 if table_schema: 344 column_type = table_schema.get(normalized_column_name) 345 346 if isinstance(column_type, exp.DataType): 347 return column_type 348 elif isinstance(column_type, str): 349 return self._to_data_type(column_type, dialect=dialect) 350 351 return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
353 def has_column( 354 self, 355 table: exp.Table | str, 356 column: exp.Column | str, 357 dialect: DialectType = None, 358 normalize: t.Optional[bool] = None, 359 ) -> bool: 360 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 361 362 normalized_column_name = self._normalize_name( 363 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 364 ) 365 366 table_schema = self.find(normalized_table, raise_on_missing=False) 367 return normalized_column_name in table_schema if table_schema else False
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` when given as a string and normalize it per `dialect` rules."""
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping given as a dict, string, or list into a dict.

    Args:
        mapping: one of
            - dict: returned as-is.
            - str: comma-separated "name: type" pairs, e.g. "a: int, b: text".
            - list: column names only; every type is set to None.
            - None: treated as an empty mapping.

    Returns:
        A dict of column name to type (types are None for list inputs).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
        IndexError: if a string entry contains no ":" separator.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns = {}
        for pair in mapping.split(","):
            # Split on the *first* colon only, so types that themselves contain
            # colons (e.g. "struct<a:int>") are not truncated.
            parts = pair.split(":", 1)
            columns[parts[0].strip()] = parts[1].strip()
        return columns
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema mapping into the list of key paths down to each table."""
    prefix = keys or []
    if depth is None:
        depth = dict_depth(schema) - 1

    paths: t.List[t.List[str]] = []
    for key, value in schema.items():
        if depth == 1 or not isinstance(value, dict):
            paths.append(prefix + [key])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue
        if not raise_on_missing:
            return None
        # "this" is the internal arg name for the table part; report it as "table".
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")

    return current
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # setdefault both creates missing intermediate dicts and returns existing ones,
    # so no separate membership check (or single-key special case) is needed.
    subd = d
    for key in keys[:-1]:
        subd = subd.setdefault(key, {})

    subd[keys[-1]] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.