sqlglot.dataframe.sql
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
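This package mirrors a slice of the PySpark DataFrame API on top of sqlglot: DataFrames are lazy expression builders, and calling .sql() returns SQL strings rather than executing anything. A minimal end-to-end sketch (the employee table and its columns are illustrative):

import sqlglot
from sqlglot.dataframe.sql import SparkSession, functions as F

# Column types must be known up front, so register table schemas with sqlglot
sqlglot.schema.add_table(
    "employee",
    {"id": "INT", "fname": "STRING", "lname": "STRING", "age": "INT", "store_id": "INT"},
    dialect="spark",
)

spark = SparkSession()
df = spark.table("employee")
print(df.where(F.col("age") > 21).sql(dialect="spark")[0])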
class SparkSession:
    DEFAULT_DIALECT = "spark"
    _instance = None

    def __init__(self):
        if not hasattr(self, "known_ids"):
            self.known_ids = set()
            self.known_branch_ids = set()
            self.known_sequence_ids = set()
            self.name_to_sequence_id_mapping = defaultdict(list)
            self.incrementing_id = 1
            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)

    def __new__(cls, *args, **kwargs) -> SparkSession:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError(
                "Only a StructType, string, or list of strings is supported for schema"
            )
        if not data:
            raise ValueError("Must provide data to create a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.tuple_(
                *map(
                    lambda x: F.lit(x).expression,
                    row if not isinstance(row, dict) else row.values(),
                )
            )
            for row in data
        ]

        sel_columns = [
            (
                F.col(name).cast(data_type).alias(name).expression
                if data_type is not None
                else F.col(name).expression
            )
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def _optimize(
        self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
    ) -> exp.Expression:
        dialect = dialect or self.dialect
        quote_identifiers(expression, dialect=dialect)
        return optimize(expression, dialect=dialect)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect))
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)

    class Builder:
        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"

        def __init__(self):
            self.dialect = "spark"

        def __getattr__(self, item) -> SparkSession.Builder:
            return self

        def __call__(self, *args, **kwargs):
            return self

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
            **kwargs: t.Any,
        ) -> SparkSession.Builder:
            if key == self.SQLFRAME_DIALECT_KEY:
                self.dialect = value
            elif map and self.SQLFRAME_DIALECT_KEY in map:
                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
            return self

        def getOrCreate(self) -> SparkSession:
            spark = SparkSession()
            spark.dialect = Dialect.get_or_raise(self.dialect)
            return spark

    @classproperty
    def builder(cls) -> Builder:
        return cls.Builder()
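createDataFrame never touches a catalog: rows are inlined as a VALUES clause. The schema may be a StructType, a schema string, or a list of column names, and rows may be dicts, lists, or tuples. A sketch, assuming the "name: type" schema-string form:

df = spark.createDataFrame([(1, "Jack"), (2, "Jill")], ["id", "name"])
typed = spark.createDataFrame([[1, 2]], "a: INT, b: INT")
named = spark.createDataFrame([{"id": 1, "name": "Jack"}])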
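SparkSession.sql parses a SELECT, CREATE, or INSERT statement and hands it back as a DataFrame, so raw SQL and DataFrame calls can be mixed freely. With the employee table registered as in the quickstart:

df = spark.sql("SELECT fname, age FROM employee WHERE store_id = 1")
older = df.where(F.col("age") > 40)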
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
        expression = expression.transform(replace_id_value, replacement_mapping).assert_is(
            exp.Select
        )
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
        from sqlglot.dataframe.sql.session import SparkSession

        dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)

        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}

        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(
                replace_id_value, replacement_mapping
            ).assert_is(exp.Select)
            if optimize:
                select_expression = t.cast(
                    exp.Select, self.spark._optimize(select_expression, dialect=dialect)
                )

            select_expression = df._replace_cte_names_with_hashes(select_expression)

            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql(dialect=dialect)
                        for expression in select_expression.expressions
                    },
                    dialect=dialect,
                )

                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )

                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)

                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")

            output_expressions.append(expression)

        return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If a column resolves to multiple CTE expressions, use each CTE left-to-right
                # and allow multiple columns with the same name in the result. This matches
                # Spark's behavior.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before.
                    # If so, use it. Otherwise, use the same CTE we used before.
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We determine the actual "join on" expression later, so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determine the join clause and select columns based on what type of columns were
        # provided for the join. The columns returned change based on how the on expression
        # is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of joining on column names only:
            * The join column names are put at the front of the select list
            * Only the join column names are deduplicated across the entire select list
              (other duplicates are allowed)
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each
            # of the left-side tables to see if they have the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match Spark's behavior, only the join clause gets deduplicated, and it is put
            # at the front of the column list
            select_column_names = [
                (
                    column.alias_or_name
                    if not isinstance(column.expression.this, exp.Star)
                    else column.sql()
                )
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of joining on expressions:
            * There is no deduplication of the results.
            * The left DataFrame's columns go first and the right one's come after. No sort
              preference is given to join columns.
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any pre-ordered columns take priority over whatever is
        provided in `ascending`. Spark has irregular behavior when the two are mixed and can
        raise runtime errors. Users shouldn't mix them anyway, so this is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            (
                exp.Ordered(this=col.expression, desc=not asc)
                if i not in pre_ordered_col_indexes
                else columns[i].column_expression
            )
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
            if minimum_num_nulls > len(null_check_columns):
                raise RuntimeError(
                    f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                    f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
                )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality difference: if you provide a replacement value whose type conflicts with
        the column's type, PySpark simply ignores your replacement, while this implementation
        tries to cast the two to a common type in some cases, so the results won't always
        match. It's best not to mix types: make sure the replacement is the same type as the
        column.

        Possibility for improvement: use the `typeof` function to get the type of the column
        and check whether it matches the type of the value provided. If not, make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage level options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
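The sketches below run through the most commonly used DataFrame methods. They reuse the quickstart setup (spark, F, and the registered employee table); every method returns a new DataFrame, and nothing is produced until .sql() is called.

df = spark.table("employee")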
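DataFrame.sql is the terminal call: it resolves pending hints, optionally runs the optimizer, and returns one SQL string per output statement (caching adds DROP VIEW and CACHE statements; writes produce CREATE or INSERT). Generator kwargs such as pretty pass through to sqlglot:

for statement in df.where(df["age"] > 30).sql(dialect="bigquery", pretty=True):
    print(statement)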
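select accepts strings or Column expressions; after a join, unqualified names are resolved against the joined CTEs left to right, per the comments in the source above. A sketch:

projected = df.select("fname", F.col("lname"), (df["age"] + F.lit(1)).alias("age_next"))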
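groupBy returns a GroupedData, and agg called directly on a DataFrame aggregates over an empty grouping. A sketch, assuming the usual PySpark-style aggregates exist in sqlglot.dataframe.sql.functions:

oldest_per_store = df.groupBy(df["store_id"]).agg(F.max(df["age"]).alias("oldest"))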
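As the two branches above show, joining on column names deduplicates the keys and moves them to the front of the select list, while joining on an expression keeps every column from both sides. A sketch (the store table is illustrative):

sqlglot.schema.add_table("store", {"store_id": "INT", "name": "STRING"}, dialect="spark")
store = spark.table("store")

by_name = df.join(store, on=["store_id"], how="inner")
by_expr = df.join(store, on=df["store_id"] == store["store_id"], how="left")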
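Per the orderBy docstring, pre-ordered columns win over the ascending flag, so pick one style. A sketch, assuming Column exposes PySpark-style asc()/desc() helpers:

by_flag = df.orderBy("age", "fname", ascending=[False, True])
by_cols = df.sort(df["age"].desc(), df["fname"].asc())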
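unionByName aligns columns by name rather than position, and with allowMissingColumns=True fills the gaps with NULLs, as the source above shows:

a = spark.createDataFrame([(1, "Jack")], ["id", "name"])
b = spark.createDataFrame([("Jill", 2)], ["name", "id"])
aligned = a.unionByName(b)

c = spark.createDataFrame([(3,)], ["id"])
padded = a.unionByName(c, allowMissingColumns=True)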
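Without a subset, dropDuplicates is just distinct(); with one, it is implemented as a row_number window over the subset columns:

deduped = df.dropDuplicates(["fname", "lname"])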
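For dropna, how applies when thresh is unset; thresh means "keep rows with at least this many non-null values among the checked columns":

any_null_gone = df.dropna(how="any", subset=["fname", "age"])
at_least_two = df.dropna(thresh=2)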
665 @operation(Operation.FROM) 666 def fillna( 667 self, 668 value: t.Union[ColumnLiterals], 669 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 670 ) -> DataFrame: 671 """ 672 Functionality Difference: If you provide a value to replace a null and that type conflicts 673 with the type of the column then PySpark will just ignore your replacement. 674 This will try to cast them to be the same in some cases. So they won't always match. 675 Best to not mix types so make sure replacement is the same type as the column 676 677 Possibility for improvement: Use `typeof` function to get the type of the column 678 and check if it matches the type of the value provided. If not then make it null. 679 """ 680 from sqlglot.dataframe.sql.functions import lit 681 682 values = None 683 columns = None 684 new_df = self.copy() 685 all_columns = self._get_outer_select_columns(new_df.expression) 686 all_column_mapping = {column.alias_or_name: column for column in all_columns} 687 if isinstance(value, dict): 688 values = list(value.values()) 689 columns = self._ensure_and_normalize_cols(list(value)) 690 if not columns: 691 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 692 if not values: 693 values = [value] * len(columns) 694 value_columns = [lit(value) for value in values] 695 696 null_replacement_mapping = { 697 column.alias_or_name: ( 698 F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) 699 ) 700 for column, value in zip(columns, value_columns) 701 } 702 null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} 703 null_replacement_columns = [ 704 null_replacement_mapping[column.alias_or_name] for column in all_columns 705 ] 706 new_df = new_df.select(*null_replacement_columns) 707 return new_df
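Usage sketch for the dict form, again against the hypothetical `employee` table from above:

    from sqlglot.dataframe.sql.session import SparkSession

    # Per-column replacements: name is rewritten as, roughly,
    # CASE WHEN name IS NULL THEN 'unknown' ELSE name END AS name.
    filled = SparkSession().table("employee").fillna({"name": "unknown"})
    print(filled.sql()[0])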
    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df
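Usage sketch (same hypothetical table): the dict form maps old value to new value, and `subset` limits which columns are rewritten:

    from sqlglot.dataframe.sql.session import SparkSession

    # Roughly: name becomes CASE WHEN name = 'n/a' THEN 'unknown' ELSE name END.
    swapped = SparkSession().table("employee").replace({"n/a": "unknown"}, subset=["name"])
    print(swapped.sql()[0])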
    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        # "is not None" so that replacing the first column (index 0) works too
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)
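The two branches above behave differently, which a quick sketch makes visible (hypothetical table as before):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    df = SparkSession().table("employee")
    # New name: appended as an aliased select expression.
    print(df.withColumn("has_name", F.col("name").isNotNull()).sql()[0])
    # Existing name: the select expression is replaced in place.
    print(df.withColumn("name", F.col("name").cast("string")).sql()[0])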
    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)
    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)
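Sketch (hypothetical table as before): `drop` rebuilds the outer select without the named columns rather than emitting any DROP syntax:

    from sqlglot.dataframe.sql.session import SparkSession

    slim = SparkSession().table("employee").drop("name")
    print(slim.sql()[0])  # selects only id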
    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)
    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)
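Because these are logical no-ops, they surface as SELECT hints rather than physical operations. A sketch, with the exact hint rendering depending on the sqlglot version:

    from sqlglot.dataframe.sql.session import SparkSession

    hinted = SparkSession().table("employee").repartition(4, "name")
    print(hinted.sql()[0])  # e.g. SELECT /*+ REPARTITION(4, name) */ id, name FROM ...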
    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
class GroupedData:
    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
        self._df = df.copy()
        self.spark = df.spark
        self.last_op = last_op
        self.group_by_cols = group_by_cols

    def _get_function_applied_columns(
        self, func_name: str, cols: t.Tuple[str, ...]
    ) -> t.List[Column]:
        func_name = func_name.lower()
        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]

    @operation(Operation.SELECT)
    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
        columns = (
            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
            if isinstance(exprs[0], dict)
            else exprs
        )
        cols = self._df._ensure_and_normalize_cols(columns)

        expression = self._df.expression.group_by(
            *[x.expression for x in self.group_by_cols]
        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
        return self._df.copy(expression=expression)

    def count(self) -> DataFrame:
        return self.agg(F.count("*").alias("count"))

    def mean(self, *cols: str) -> DataFrame:
        return self.avg(*cols)

    def avg(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("avg", cols))

    def max(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("max", cols))

    def min(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("min", cols))

    def sum(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("sum", cols))

    def pivot(self, *cols: str) -> DataFrame:
        raise NotImplementedError("Pivot is not currently implemented")
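A usage sketch via `DataFrame.groupBy` (hypothetical `employee` table as before), showing both the Column form and the dict form of `agg`:

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    df = SparkSession().table("employee")
    print(df.groupBy(F.col("name")).count().sql()[0])             # COUNT(*) AS count, grouped by name
    print(df.groupBy(F.col("name")).agg({"id": "max"}).sql()[0])  # dict form becomes MAX(id)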
class Column:
    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
        from sqlglot.dataframe.sql.session import SparkSession

        if isinstance(expression, Column):
            expression = expression.expression  # type: ignore
        elif expression is None or not isinstance(expression, (str, exp.Expression)):
            expression = self._lit(expression).expression  # type: ignore
        elif not isinstance(expression, exp.Column):
            expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
                SparkSession().dialect.normalize_identifier, copy=False
            )
        if expression is None:
            raise ValueError(f"Could not parse {expression}")

        self.expression: exp.Expression = expression  # type: ignore

    def __repr__(self):
        return repr(self.expression)

    def __hash__(self):
        return hash(self.expression)

    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.EQ, other)

    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.NEQ, other)

    def __gt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GT, other)

    def __ge__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GTE, other)

    def __lt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LT, other)

    def __le__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LTE, other)

    def __and__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.And, other)

    def __or__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Or, other)

    def __mod__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mod, other)

    def __add__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Add, other)

    def __sub__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Sub, other)

    def __mul__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mul, other)

    def __truediv__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __div__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __neg__(self) -> Column:
        return self.unary_op(exp.Neg)

    def __radd__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Add, other)

    def __rsub__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Sub, other)

    def __rmul__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mul, other)

    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rmod__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mod, other)

    def __pow__(self, power: ColumnOrLiteral, modulo=None):
        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))

    def __rpow__(self, power: ColumnOrLiteral):
        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))

    def __invert__(self):
        return self.unary_op(exp.Not)

    def __rand__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.And, other)

    def __ror__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Or, other)

    @classmethod
    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
        return cls(value)

    @classmethod
    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]

    @classmethod
    def _lit(cls, value: ColumnOrLiteral) -> Column:
        if isinstance(value, dict):
            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
            return cls(exp.Struct(expressions=columns))
        return cls(exp.convert(value))

    @classmethod
    def invoke_anonymous_function(
        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
    ) -> Column:
        columns = [] if column is None else [cls.ensure_col(column)]
        column_args = [cls.ensure_col(arg) for arg in args]
        expressions = [x.expression for x in columns + column_args]
        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
        return Column(new_expression)

    @classmethod
    def invoke_expression_over_column(
        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
    ) -> Column:
        ensured_column = None if column is None else cls.ensure_col(column)
        ensure_expression_values = {
            k: (
                [Column.ensure_col(x).expression for x in v]
                if is_iterable(v)
                else Column.ensure_col(v).expression
            )
            for k, v in kwargs.items()
            if v is not None
        }
        new_expression = (
            callable_expression(**ensure_expression_values)
            if ensured_column is None
            else callable_expression(
                this=ensured_column.column_expression, **ensure_expression_values
            )
        )
        return Column(new_expression)

    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
        )

    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(
            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
        )

    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
        return Column(klass(this=self.column_expression, **kwargs))

    @property
    def is_alias(self):
        return isinstance(self.expression, exp.Alias)

    @property
    def is_column(self):
        return isinstance(self.expression, exp.Column)

    @property
    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
        return self.expression.unalias()

    @property
    def alias_or_name(self) -> str:
        return self.expression.alias_or_name

    @classmethod
    def ensure_literal(cls, value) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        if isinstance(value, cls):
            value = value.expression
        if not isinstance(value, exp.Literal):
            return lit(value)
        return Column(value)

    def copy(self) -> Column:
        return Column(self.expression.copy())

    def set_table_name(self, table_name: str, copy=False) -> Column:
        expression = self.expression.copy() if copy else self.expression
        expression.set("table", exp.to_identifier(table_name))
        return Column(expression)

    def sql(self, **kwargs) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})

    def alias(self, name: str) -> Column:
        from sqlglot.dataframe.sql.session import SparkSession

        dialect = SparkSession().dialect
        alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
        new_expression = exp.alias_(
            self.column_expression,
            alias.this if isinstance(alias, exp.Column) else name,
            dialect=dialect,
        )
        return Column(new_expression)

    def asc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
        return Column(new_expression)

    def desc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
        return Column(new_expression)

    asc_nulls_first = asc

    def asc_nulls_last(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
        return Column(new_expression)

    def desc_nulls_first(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
        return Column(new_expression)

    desc_nulls_last = desc

    def when(self, condition: Column, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import when

        column_with_if = when(condition, value)
        if not isinstance(self.expression, exp.Case):
            return column_with_if
        new_column = self.copy()
        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
        return new_column

    def otherwise(self, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        true_value = value if isinstance(value, Column) else lit(value)
        new_column = self.copy()
        new_column.expression.set("default", true_value.column_expression)
        return new_column

    def isNull(self) -> Column:
        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
        return Column(new_expression)

    def isNotNull(self) -> Column:
        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
        return Column(new_expression)

    def cast(self, dataType: t.Union[str, DataType]) -> Column:
        """
        Functionality difference: PySpark's cast accepts an instance of the DataType class.
        Sqlglot doesn't currently replicate that class, so a DataType is converted to its
        string form via `simpleString()`; otherwise only strings are accepted.
        """
        from sqlglot.dataframe.sql.session import SparkSession

        if isinstance(dataType, DataType):
            dataType = dataType.simpleString()
        return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))

    def startswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "STARTSWITH", value)

    def endswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "ENDSWITH", value)

    def rlike(self, regexp: str) -> Column:
        return self.invoke_expression_over_column(
            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
        )

    def like(self, other: str):
        return self.invoke_expression_over_column(
            self, exp.Like, expression=self._lit(other).expression
        )

    def ilike(self, other: str):
        return self.invoke_expression_over_column(
            self, exp.ILike, expression=self._lit(other).expression
        )

    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
        length = self._lit(length) if not isinstance(length, Column) else length
        return Column.invoke_expression_over_column(
            self, exp.Substring, start=startPos.expression, length=length.expression
        )

    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [self._lit(x).expression for x in columns]
        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore

    def between(
        self,
        lowerBound: ColumnOrLiteral,
        upperBound: ColumnOrLiteral,
    ) -> Column:
        lower_bound_exp = (
            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
        )
        upper_bound_exp = (
            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
        )
        return Column(
            exp.Between(
                this=self.column_expression,
                low=lower_bound_exp.expression,
                high=upper_bound_exp.expression,
            )
        )

    def over(self, window: WindowSpec) -> Column:
        window_expression = window.expression.copy()
        window_expression.set("this", self.column_expression)
        return Column(window_expression)
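Columns compose through the operator overloads above and render through `.sql()`; a small sketch (the rendered output shapes are approximate):

    from sqlglot.dataframe.sql import functions as F

    expr = (F.col("id") + 1).alias("next_id")
    print(expr.sql())  # roughly: id + 1 AS next_id, in the session dialect

    cond = F.col("id").between(1, 10) & F.col("name").ilike("%a%")
    print(cond.sql())  # roughly: id BETWEEN 1 AND 10 AND name ILIKE '%a%'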
class DataFrameNaFunctions:
    def __init__(self, df: DataFrame):
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
class Window:
    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)

    unboundedPreceding: int = _JAVA_MIN_LONG

    unboundedFollowing: int = _JAVA_MAX_LONG

    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rangeBetween(start, end)
class WindowSpec:
    def __init__(self, expression: exp.Expression = exp.Window()):
        self.expression = expression

    def copy(self):
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        return self.expression.sql(dialect=SparkSession().dialect, **kwargs)

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
        partition_by_expressions.extend(expressions)
        window_spec.expression.set("partition_by", partition_by_expressions)
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        if window_spec.expression.args.get("order") is None:
            window_spec.expression.set("order", exp.Order(expressions=[]))
        order_by = window_spec.expression.args["order"].expressions
        order_by.extend(expressions)
        window_spec.expression.args["order"].set("expressions", order_by)
        return window_spec

    def _calc_start_end(
        self, start: int, end: int
    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
            "start_side": None,
            "end_side": None,
        }
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "start_side": "PRECEDING",
                    "start": (
                        "UNBOUNDED"
                        if start <= Window.unboundedPreceding
                        else F.lit(start).expression
                    ),
                },
            }
        if end == Window.currentRow:
            kwargs["end"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "end_side": "FOLLOWING",
                    "end": (
                        "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression
                    ),
                },
            }
        return kwargs

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "ROWS"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "RANGE"
        window_spec.expression.set(
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec
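A sketch tying `Window`, `WindowSpec`, and `Column.over` together; `F.sum` is assumed to mirror PySpark's aggregate helper, as the other functions in this module do:

    from sqlglot.dataframe.sql import Window, functions as F

    w = (
        Window.partitionBy("name")
        .orderBy("id")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    running = F.sum("id").over(w).alias("running_total")
    # roughly: SUM(id) OVER (PARTITION BY name ORDER BY id
    #          ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total
    print(running.sql())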
class DataFrameReader:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame
        from sqlglot.dataframe.sql.session import SparkSession

        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)

        return DataFrame(
            self.spark,
            exp.Select()
            .from_(
                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
                    SparkSession().dialect.normalize_identifier
                )
            )
            .select(
                *(
                    column
                    for column in sqlglot.schema.column_names(
                        tableName, dialect=SparkSession().dialect
                    )
                )
            ),
        )
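Usage sketch; the `db1.events` table and its columns are hypothetical, and pre-registering them with a column mapping is what lets the reader expand the select into concrete column names:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("db1.events", {"ts": "TIMESTAMP", "user_id": "INT"}, dialect="spark")
    df = SparkSession().read.table("db1.events")
    print(df.sql()[0])  # roughly: SELECT ts, user_id FROM db1.events ...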
class DataFrameWriter:
    def __init__(
        self,
        df: DataFrame,
        spark: t.Optional[SparkSession] = None,
        mode: t.Optional[str] = None,
        by_name: bool = False,
    ):
        self._df = df
        self._spark = spark or df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        return DataFrameWriter(
            **{
                k[1:] if k.startswith("_") else k: v
                for k, v in object_to_dict(self, **kwargs).items()
            }
        )

    def sql(self, **kwargs) -> t.List[str]:
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        return self.copy(_mode=saveMode)

    @property
    def byName(self):
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        from sqlglot.dataframe.sql.session import SparkSession

        output_expression_container = exp.Insert(
            **{
                "this": exp.to_table(tableName),
                "overwrite": overwrite,
            }
        )
        df = self._df.copy(output_expression_container=output_expression_container)
        if self._by_name:
            columns = sqlglot.schema.column_names(
                tableName, only_visible=True, dialect=SparkSession().dialect
            )
            df = df._convert_leaf_to_cte().select(*columns)

        return self.copy(_df=df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        if format is not None:
            raise NotImplementedError("Providing a format in saveAsTable is not supported")
        exists, replace, mode = None, None, mode or str(self._mode)
        if mode == "append":
            return self.insertInto(name)
        if mode == "ignore":
            exists = True
        if mode == "overwrite":
            replace = True
        output_expression_container = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
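A sketch of the writer modes, using the hypothetical `db1.events` table registered above and the `write` property on DataFrame:

    from sqlglot.dataframe.sql.session import SparkSession

    df = SparkSession().read.table("db1.events")
    # mode("overwrite") produces CREATE OR REPLACE TABLE ... AS SELECT ...
    print(df.write.mode("overwrite").saveAsTable("db1.events_backup").sql()[0])
    # mode("append") routes through insertInto and produces an INSERT INTO.
    print(df.write.mode("append").saveAsTable("db1.events_backup").sql()[0])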