sqlglot.schema
"""Database schema abstractions: an abstract `Schema` interface plus a
nested-mapping implementation (`MappingSchema`) that supports one to three
levels of table qualification (table / db.table / catalog.db.table)."""

from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth, first
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # A column mapping may be a {name: type} dict, a "name: type, ..." string,
    # or a list of column names (types unknown) -- see `ensure_column_mapping`.
    ColumnMapping = t.Union[t.Dict, str, t.List]


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema supports, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True


class AbstractMappingSchema:
    """Base class for schemas backed by a nested dict.

    A trie is built over the *reversed* table parts (catalog-first) so that
    partially-qualified table names can be resolved by prefix lookup.
    """

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # NOTE: the genexp variable `t` shadows the module-level `typing` alias
        # inside this expression only.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        # Lazily populated cache for the `supported_table_args` property.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        # Empty iff no tables have been registered in the mapping.
        return not self.mapping

    def depth(self) -> int:
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Computed lazily; cached once the mapping is non-empty.
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # empty mapping -> no supported args
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                # TABLE_PARTS is ordered ("this", "db", "catalog"), so the
                # slice yields the args this nesting level can address.
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Returns the non-empty name parts of `table`, innermost first.

        For `READ_CSV(...)` pseudo-tables, the file name is the sole part.
        """
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table is under-qualified; it can only be resolved if the
            # remaining qualifiers are unambiguous in the trie.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        # `parts` is innermost-first, while the mapping is nested
        # catalog-first, hence the `reversed` below.
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache for string-to-DataType conversions (see `_to_data_type`).
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached nesting depth; 0 means "not yet computed" (see `depth`).
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Builds a new `MappingSchema` from an existing one's state."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """Like the base `find`, optionally converting `str` column types to `DataType`."""
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Returns a shallow copy of this schema; `kwargs` override constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # A known table without a new column mapping is a no-op update.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping is nested catalog-first, so the parts are reversed;
        # the trie is updated in place to keep lookups consistent.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Returns the column names of `table`, optionally filtered to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Returns the `DataType` of `column` in `table`, or the UNKNOWN type if unresolvable."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type, dialect=dialect)

        # Unknown table, unknown column, or an unusable stored type.
        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Returns whether `column` appears in `table`'s schema (after normalization)."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            # zip(keys, keys) produces the (name, key) pairs `nested_get` expects,
            # using each key as its own error label.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                # A dict-valued "column" means this path is nested deeper than
                # the first table found, i.e. the mapping shape is inconsistent.
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` if needed and normalizes its identifier parts in place (on a copy when normalizing)."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when we are about to mutate the expression's identifiers.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Returns the normalized text of `name`, falling back to this schema's dialect/normalize settings."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parses `identifier` if needed and normalizes it per `dialect` (no-op when `normalize` is falsy)."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Returns `schema` unchanged if it is already a `Schema`, else wraps it in a `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerces any supported `ColumnMapping` form (dict, "name: type, ..." string, list, None) into a dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # List of column names only, with unknown types.
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Returns the key paths of `schema` down to (but excluding) the innermost level, e.g. table paths."""
    tables = []
    keys = keys or []
    # Depth minus one, so the leaf level (columns) is not flattened into the paths.
    depth = dict_depth(schema) - 1 if depth is None else depth

    for k, v in schema.items():
        if depth == 1 or not isinstance(v, dict):
            tables.append(keys + [k])
        elif depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    # Walk/create intermediate dicts, then assign at the final key.
    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
19class Schema(abc.ABC): 20 """Abstract base class for database schemas""" 21 22 dialect: DialectType 23 24 @abc.abstractmethod 25 def add_table( 26 self, 27 table: exp.Table | str, 28 column_mapping: t.Optional[ColumnMapping] = None, 29 dialect: DialectType = None, 30 normalize: t.Optional[bool] = None, 31 match_depth: bool = True, 32 ) -> None: 33 """ 34 Register or update a table. Some implementing classes may require column information to also be provided. 35 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 36 37 Args: 38 table: the `Table` expression instance or string representing the table. 39 column_mapping: a column mapping that describes the structure of the table. 40 dialect: the SQL dialect that will be used to parse `table` if it's a string. 41 normalize: whether to normalize identifiers according to the dialect of interest. 42 match_depth: whether to enforce that the table must match the schema's depth or not. 43 """ 44 45 @abc.abstractmethod 46 def column_names( 47 self, 48 table: exp.Table | str, 49 only_visible: bool = False, 50 dialect: DialectType = None, 51 normalize: t.Optional[bool] = None, 52 ) -> t.Sequence[str]: 53 """ 54 Get the column names for a table. 55 56 Args: 57 table: the `Table` expression instance. 58 only_visible: whether to include invisible columns. 59 dialect: the SQL dialect that will be used to parse `table` if it's a string. 60 normalize: whether to normalize identifiers according to the dialect of interest. 61 62 Returns: 63 The sequence of column names. 64 """ 65 66 @abc.abstractmethod 67 def get_column_type( 68 self, 69 table: exp.Table | str, 70 column: exp.Column | str, 71 dialect: DialectType = None, 72 normalize: t.Optional[bool] = None, 73 ) -> exp.DataType: 74 """ 75 Get the `sqlglot.exp.DataType` type of a column in the schema. 76 77 Args: 78 table: the source table. 79 column: the target column. 
80 dialect: the SQL dialect that will be used to parse `table` if it's a string. 81 normalize: whether to normalize identifiers according to the dialect of interest. 82 83 Returns: 84 The resulting column type. 85 """ 86 87 def has_column( 88 self, 89 table: exp.Table | str, 90 column: exp.Column | str, 91 dialect: DialectType = None, 92 normalize: t.Optional[bool] = None, 93 ) -> bool: 94 """ 95 Returns whether `column` appears in `table`'s schema. 96 97 Args: 98 table: the source table. 99 column: the target column. 100 dialect: the SQL dialect that will be used to parse `table` if it's a string. 101 normalize: whether to normalize identifiers according to the dialect of interest. 102 103 Returns: 104 True if the column appears in the schema, False otherwise. 105 """ 106 name = column if isinstance(column, str) else column.name 107 return name in self.column_names(table, dialect=dialect, normalize=normalize) 108 109 @property 110 @abc.abstractmethod 111 def supported_table_args(self) -> t.Tuple[str, ...]: 112 """ 113 Table arguments this schema support, e.g. `("this", "db", "catalog")` 114 """ 115 116 @property 117 def empty(self) -> bool: 118 """Returns whether the schema is empty.""" 119 return True
Abstract base class for database schemas
24 @abc.abstractmethod 25 def add_table( 26 self, 27 table: exp.Table | str, 28 column_mapping: t.Optional[ColumnMapping] = None, 29 dialect: DialectType = None, 30 normalize: t.Optional[bool] = None, 31 match_depth: bool = True, 32 ) -> None: 33 """ 34 Register or update a table. Some implementing classes may require column information to also be provided. 35 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 36 37 Args: 38 table: the `Table` expression instance or string representing the table. 39 column_mapping: a column mapping that describes the structure of the table. 40 dialect: the SQL dialect that will be used to parse `table` if it's a string. 41 normalize: whether to normalize identifiers according to the dialect of interest. 42 match_depth: whether to enforce that the table must match the schema's depth or not. 43 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
45 @abc.abstractmethod 46 def column_names( 47 self, 48 table: exp.Table | str, 49 only_visible: bool = False, 50 dialect: DialectType = None, 51 normalize: t.Optional[bool] = None, 52 ) -> t.Sequence[str]: 53 """ 54 Get the column names for a table. 55 56 Args: 57 table: the `Table` expression instance. 58 only_visible: whether to include invisible columns. 59 dialect: the SQL dialect that will be used to parse `table` if it's a string. 60 normalize: whether to normalize identifiers according to the dialect of interest. 61 62 Returns: 63 The sequence of column names. 64 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
66 @abc.abstractmethod 67 def get_column_type( 68 self, 69 table: exp.Table | str, 70 column: exp.Column | str, 71 dialect: DialectType = None, 72 normalize: t.Optional[bool] = None, 73 ) -> exp.DataType: 74 """ 75 Get the `sqlglot.exp.DataType` type of a column in the schema. 76 77 Args: 78 table: the source table. 79 column: the target column. 80 dialect: the SQL dialect that will be used to parse `table` if it's a string. 81 normalize: whether to normalize identifiers according to the dialect of interest. 82 83 Returns: 84 The resulting column type. 85 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
87 def has_column( 88 self, 89 table: exp.Table | str, 90 column: exp.Column | str, 91 dialect: DialectType = None, 92 normalize: t.Optional[bool] = None, 93 ) -> bool: 94 """ 95 Returns whether `column` appears in `table`'s schema. 96 97 Args: 98 table: the source table. 99 column: the target column. 100 dialect: the SQL dialect that will be used to parse `table` if it's a string. 101 normalize: whether to normalize identifiers according to the dialect of interest. 102 103 Returns: 104 True if the column appears in the schema, False otherwise. 105 """ 106 name = column if isinstance(column, str) else column.name 107 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
122class AbstractMappingSchema: 123 def __init__( 124 self, 125 mapping: t.Optional[t.Dict] = None, 126 ) -> None: 127 self.mapping = mapping or {} 128 self.mapping_trie = new_trie( 129 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 130 ) 131 self._supported_table_args: t.Tuple[str, ...] = tuple() 132 133 @property 134 def empty(self) -> bool: 135 return not self.mapping 136 137 def depth(self) -> int: 138 return dict_depth(self.mapping) 139 140 @property 141 def supported_table_args(self) -> t.Tuple[str, ...]: 142 if not self._supported_table_args and self.mapping: 143 depth = self.depth() 144 145 if not depth: # None 146 self._supported_table_args = tuple() 147 elif 1 <= depth <= 3: 148 self._supported_table_args = exp.TABLE_PARTS[:depth] 149 else: 150 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 151 152 return self._supported_table_args 153 154 def table_parts(self, table: exp.Table) -> t.List[str]: 155 if isinstance(table.this, exp.ReadCSV): 156 return [table.this.name] 157 return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)] 158 159 def find( 160 self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False 161 ) -> t.Optional[t.Any]: 162 """ 163 Returns the schema of a given table. 164 165 Args: 166 table: the target table. 167 raise_on_missing: whether to raise in case the schema is not found. 168 ensure_data_types: whether to convert `str` types to their `DataType` equivalents. 169 170 Returns: 171 The schema of the target table. 
172 """ 173 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 174 value, trie = in_trie(self.mapping_trie, parts) 175 176 if value == TrieResult.FAILED: 177 return None 178 179 if value == TrieResult.PREFIX: 180 possibilities = flatten_schema(trie) 181 182 if len(possibilities) == 1: 183 parts.extend(possibilities[0]) 184 else: 185 message = ", ".join(".".join(parts) for parts in possibilities) 186 if raise_on_missing: 187 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 188 return None 189 190 return self.nested_get(parts, raise_on_missing=raise_on_missing) 191 192 def nested_get( 193 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 194 ) -> t.Optional[t.Any]: 195 return nested_get( 196 d or self.mapping, 197 *zip(self.supported_table_args, reversed(parts)), 198 raise_on_missing=raise_on_missing, 199 )
140 @property 141 def supported_table_args(self) -> t.Tuple[str, ...]: 142 if not self._supported_table_args and self.mapping: 143 depth = self.depth() 144 145 if not depth: # None 146 self._supported_table_args = tuple() 147 elif 1 <= depth <= 3: 148 self._supported_table_args = exp.TABLE_PARTS[:depth] 149 else: 150 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 151 152 return self._supported_table_args
159 def find( 160 self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False 161 ) -> t.Optional[t.Any]: 162 """ 163 Returns the schema of a given table. 164 165 Args: 166 table: the target table. 167 raise_on_missing: whether to raise in case the schema is not found. 168 ensure_data_types: whether to convert `str` types to their `DataType` equivalents. 169 170 Returns: 171 The schema of the target table. 172 """ 173 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 174 value, trie = in_trie(self.mapping_trie, parts) 175 176 if value == TrieResult.FAILED: 177 return None 178 179 if value == TrieResult.PREFIX: 180 possibilities = flatten_schema(trie) 181 182 if len(possibilities) == 1: 183 parts.extend(possibilities[0]) 184 else: 185 message = ", ".join(".".join(parts) for parts in possibilities) 186 if raise_on_missing: 187 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 188 return None 189 190 return self.nested_get(parts, raise_on_missing=raise_on_missing)
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Lazily populated cache of string type -> exp.DataType (see _to_data_type).
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Memoized nesting depth; 0 means "not computed yet" (see depth()).
        self._depth = 0
        schema = {} if schema is None else schema

        # Normalize identifiers up front so later lookups are consistent.
        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` that reuses another instance's mapping and settings."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def find(
        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
    ) -> t.Optional[t.Any]:
        """
        Return the schema of `table`, optionally coercing string types to `exp.DataType`.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.
            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

        Returns:
            The schema of the target table.
        """
        schema = super().find(
            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
        )
        if ensure_data_types and isinstance(schema, dict):
            # Parse any string-typed columns into exp.DataType instances.
            schema = {
                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
                for col, dtype in schema.items()
            }

        return schema

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the copied constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # No-op if the table is already registered and no new columns were provided.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names of `table`, filtered by `self.visible` when `only_visible` is set."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the `exp.DataType` of `column` in `table`, or the `unknown` type if unresolved."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a raw string, e.g. "int" - parse it on demand.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Return True if `column` appears in `table`'s schema, False otherwise."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # Validate that every leaf is a non-empty {column: type} mapping at a
            # consistent nesting depth before normalizing any identifiers.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if not columns:
                raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its part identifiers."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier and return its string name."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Return the nesting depth of the schema, excluding the column level."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
    normalize: bool = True,
) -> None:
    """Initialize the schema from a nested mapping, normalizing identifiers if requested."""
    self.dialect = dialect
    self.visible = {} if visible is None else visible
    self.normalize = normalize
    # Lazily populated cache of string type -> exp.DataType (see _to_data_type).
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
    # Memoized nesting depth; 0 means "not computed yet" (see depth()).
    self._depth = 0
    schema = {} if schema is None else schema

    # Normalize identifiers up front so later lookups are consistent.
    super().__init__(self._normalize(schema) if self.normalize else schema)
def find(
    self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Optional[t.Any]:
    """
    Return the schema of `table`, optionally coercing string types to `exp.DataType`.

    Args:
        table: the target table.
        raise_on_missing: whether to raise in case the schema is not found.
        ensure_data_types: whether to convert `str` types to their `DataType` equivalents.

    Returns:
        The schema of the target table.
    """
    schema = super().find(
        table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
    )
    if ensure_data_types and isinstance(schema, dict):
        # Parse any string-typed columns into exp.DataType instances.
        schema = {
            col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
            for col, dtype in schema.items()
        }

    return schema
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
- ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
def add_table(
    self,
    table: exp.Table | str,
    column_mapping: t.Optional[ColumnMapping] = None,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
    match_depth: bool = True,
) -> None:
    """
    Register or update a table. Updates are only performed if a new column mapping is provided.
    The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
        dialect: the SQL dialect that will be used to parse `table` if it's a string.
        normalize: whether to normalize identifiers according to the dialect of interest.
        match_depth: whether to enforce that the table must match the schema's depth or not.
    """
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
        raise SchemaError(
            f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
            f"schema's nesting level: {self.depth()}."
        )

    normalized_column_mapping = {
        self._normalize_name(key, dialect=dialect, normalize=normalize): value
        for key, value in ensure_column_mapping(column_mapping).items()
    }

    # No-op if the table is already registered and no new columns were provided.
    schema = self.find(normalized_table, raise_on_missing=False)
    if schema and not normalized_column_mapping:
        return

    parts = self.table_parts(normalized_table)

    nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
    new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
def column_names(
    self,
    table: exp.Table | str,
    only_visible: bool = False,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> t.List[str]:
    """Return the column names of `table`; empty list when the table is unknown."""
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    schema = self.find(normalized_table)
    if schema is None:
        return []

    if not only_visible or not self.visible:
        return list(schema)

    # Filter the columns down to the ones marked visible for this table.
    visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
    return [col for col in schema if col in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
def get_column_type(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """Return the `exp.DataType` of `column` in `table`, or the `unknown` type if unresolved."""
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    normalized_column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(normalized_table, raise_on_missing=False)
    if table_schema:
        column_type = table_schema.get(normalized_column_name)

        if isinstance(column_type, exp.DataType):
            return column_type
        elif isinstance(column_type, str):
            # Stored as a raw string, e.g. "int" - parse it on demand.
            return self._to_data_type(column_type, dialect=dialect)

    return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
def has_column(
    self,
    table: exp.Table | str,
    column: exp.Column | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> bool:
    """Return True if `column` appears in `table`'s schema, False otherwise."""
    normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

    normalized_column_name = self._normalize_name(
        column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
    )

    table_schema = self.find(normalized_table, raise_on_missing=False)
    return normalized_column_name in table_schema if table_schema else False
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parse `identifier` if it's a string and optionally normalize it for `dialect`."""
    ident = (
        exp.parse_identifier(identifier, dialect=dialect)
        if isinstance(identifier, str)
        else identifier
    )

    if not normalize:
        return ident

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    ident.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(ident)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping into a dictionary of {column_name: column_type}.

    Accepted forms:
        - None: returns an empty mapping.
        - dict: returned as-is.
        - str: comma-separated "name: type" pairs, e.g. "a: int, b: text".
        - list: column names only; types default to None.

    Raises:
        ValueError: if `mapping` is of an unsupported type or a string entry
            has no ":" separator.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        for name_type_str in mapping.split(","):
            # partition (not split) so types that themselves contain ":" are
            # preserved intact, e.g. "m: map<int:str>" -> {"m": "map<int:str>"}
            name, sep, dtype = name_type_str.partition(":")
            if not sep:
                raise ValueError(f"Invalid column mapping entry: {name_type_str.strip()}")
            column_mapping[name.strip()] = dtype.strip()
        return column_mapping
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema mapping into a list of table key paths (columns excluded)."""
    if depth is None:
        # One less than the full dict depth, since the column level is not a table path.
        depth = dict_depth(schema) - 1

    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for name, value in schema.items():
        if depth == 1 or not isinstance(value, dict):
            paths.append(prefix + [name])
        elif depth >= 2:
            paths.extend(flatten_schema(value, depth - 1, prefix + [name]))

    return paths
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where `key` is the key in the dictionary to get
            and `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Optional[t.Any] = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is the expression-internal arg name; report it as "table".
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where `key` is the key in the dictionary to get and `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk (creating as needed) down to the parent of the final key.
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.