
sqlglot.dataframe.sql

from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
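
A quick end-to-end sketch of how these pieces fit together, adapted from the project's DataFrame documentation (the employee table and its schema are illustrative):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Register a schema so the optimizer can qualify and type columns.
sqlglot.schema.add_table(
    "employee",
    {"employee_id": "INT", "fname": "STRING", "lname": "STRING", "age": "INT"},
)

df = (
    SparkSession()
    .table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

print(df.sql(pretty=True)[0])  # sql() returns a list of SQL strings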
class SparkSession:
class SparkSession:
    DEFAULT_DIALECT = "spark"
    _instance = None

    def __init__(self):
        if not hasattr(self, "known_ids"):
            self.known_ids = set()
            self.known_branch_ids = set()
            self.known_sequence_ids = set()
            self.name_to_sequence_id_mapping = defaultdict(list)
            self.incrementing_id = 1
            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)

    def __new__(cls, *args, **kwargs) -> SparkSession:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.tuple_(
                *map(
                    lambda x: F.lit(x).expression,
                    row if not isinstance(row, dict) else row.values(),
                )
            )
            for row in data
        ]

        sel_columns = [
            (
                F.col(name).cast(data_type).alias(name).expression
                if data_type is not None
                else F.col(name).expression
            )
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)

    class Builder:
        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"

        def __init__(self):
            self.dialect = "spark"

        def __getattr__(self, item) -> SparkSession.Builder:
            return self

        def __call__(self, *args, **kwargs):
            return self

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
            **kwargs: t.Any,
        ) -> SparkSession.Builder:
            if key == self.SQLFRAME_DIALECT_KEY:
                self.dialect = value
            elif map and self.SQLFRAME_DIALECT_KEY in map:
                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
            return self

        def getOrCreate(self) -> SparkSession:
            spark = SparkSession()
            spark.dialect = Dialect.get_or_raise(self.dialect)
            return spark

    @classproperty
    def builder(cls) -> Builder:
        return cls.Builder()
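
Because __new__ caches a single instance, SparkSession behaves as a process-wide singleton: repeated construction returns the same object, and Builder.getOrCreate relies on this to mutate the shared session's dialect. A minimal illustration:

a = SparkSession()
b = SparkSession.builder.getOrCreate()
assert a is b  # both names refer to the one shared session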
DEFAULT_DIALECT = 'spark'
read: DataFrameReader
def table(self, tableName: str) -> DataFrame:
def createDataFrame(self, data: Sequence[Union[Dict[str, ColumnLiterals], List[ColumnLiterals], Tuple]], schema: Optional[SchemaInput] = None, samplingRatio: Optional[float] = None, verifySchema: bool = False) -> DataFrame:
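
A short sketch of createDataFrame with column names supplied as a list-of-strings schema; the generated alias and exact SQL text below are indicative only:

spark = SparkSession()
df = spark.createDataFrame([(1, "Jack"), (2, "Jill")], ["id", "name"])

# The rows become a VALUES clause wrapped in a SELECT, roughly:
#   SELECT id, name FROM VALUES (1, 'Jack'), (2, 'Jill') AS a1(id, name)
print(df.sql(optimize=False)[0])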
def sql(self, sqlQuery: str) -> DataFrame:
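
sql parses the query in the session's dialect and wraps it in a DataFrame, so it can be transformed further before SQL is generated again, possibly for a different dialect. A sketch (assumes the employee schema from the earlier example is registered; otherwise pass optimize=False when generating):

df = SparkSession().sql("SELECT fname, lname FROM employee WHERE age > 30")
df = df.where(F.col("age") < 65)
print(df.sql(dialect="duckdb")[0])  # re-emit the query as DuckDB SQL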
builder: SparkSession.Builder
class SparkSession.Builder:
SQLFRAME_DIALECT_KEY = 'sqlframe.dialect'
dialect
def config(self, key: Optional[str] = None, value: Optional[Any] = None, *, map: Optional[Dict[str, Any]] = None, **kwargs: Any) -> SparkSession.Builder:
def getOrCreate(self) -> SparkSession:
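
Because __getattr__ and __call__ both return the builder, arbitrary PySpark-style builder chains are tolerated as no-ops; only the sqlframe.dialect config key is actually read. A sketch (appName is just an arbitrary example of a swallowed call):

spark = (
    SparkSession.builder
    .appName("ignored")                      # no-op, kept for PySpark compatibility
    .config("sqlframe.dialect", "bigquery")  # switches the session's dialect
    .getOrCreate()
)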
class DataFrame:
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
            expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
        from sqlglot.dataframe.sql.session import SparkSession

        dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)

        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}

        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                quote_identifiers(select_expression, dialect=dialect)
                select_expression = t.cast(
                    exp.Select, optimize_func(select_expression, dialect=dialect)
                )

            select_expression = df._replace_cte_names_with_hashes(select_expression)

            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql(dialect=dialect)
                        for expression in select_expression.expressions
                    },
                    dialect=dialect,
                )

                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )

                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)

                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")

            output_expressions.append(expression)

        return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
                # of Spark.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
                    # use the same CTE we used before
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We will determine actual "join on" expression later so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determines the join clause and select columns to be used based on what type of columns were provided for
        # the join. The columns returned change based on how the on expression is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of join on column names only:
            * The column names are put at the front of the select list
            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each of the left side
            # tables to see if they have the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
            select_column_names = [
                (
                    column.alias_or_name
                    if not isinstance(column.expression.this, exp.Star)
                    else column.sql()
                )
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of join on expressions:
            * There is no deduplication of the results.
            * The left join dataframe columns go first and right come after. No sort preference is given to join columns
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
537        """
538        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
539        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
540        is unlikely to come up.
541        """
542        columns = self._ensure_and_normalize_cols(cols)
543        pre_ordered_col_indexes = [
544            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
545        ]
546        if ascending is None:
547            ascending = [True] * len(columns)
548        elif not isinstance(ascending, list):
549            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            (
                exp.Ordered(this=col.expression, desc=not asc)
                if i not in pre_ordered_col_indexes
                else columns[i].column_expression
            )
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
        if minimum_num_nulls > len(null_check_columns):
            raise RuntimeError(
                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
            )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: ColumnLiterals,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality Difference: If you provide a value to replace a null and that type conflicts
        with the type of the column, then PySpark will just ignore your replacement.
        This will try to cast them to be the same in some cases, so they won't always match.
        Best not to mix types, so make sure the replacement is the same type as the column.

        Possibility for improvement: Use the `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not, then make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
DataFrame(spark: SparkSession, expression: sqlglot.expressions.Select, branch_id: Optional[str] = None, sequence_id: Optional[str] = None, last_op: sqlglot.dataframe.sql.operations.Operation = <Operation.INIT: -1>, pending_hints: Optional[List[sqlglot.expressions.Expression]] = None, output_expression_container: Optional[OutputExpressionContainer] = None, **kwargs)
spark
expression
branch_id
sequence_id
last_op
pending_hints
output_expression_container
sparkSession
write
latest_cte_name: str
pending_join_hints
pending_partition_hints
columns: List[str]
na: DataFrameNaFunctions
def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> List[str]:
299    def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
300        from sqlglot.dataframe.sql.session import SparkSession
301
302        dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)
303
304        df = self._resolve_pending_hints()
305        select_expressions = df._get_select_expressions()
306        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
307        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
308
309        for expression_type, select_expression in select_expressions:
310            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
311            if optimize:
312                quote_identifiers(select_expression, dialect=dialect)
313                select_expression = t.cast(
314                    exp.Select, optimize_func(select_expression, dialect=dialect)
315                )
316
317            select_expression = df._replace_cte_names_with_hashes(select_expression)
318
319            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
320            if expression_type == exp.Cache:
321                cache_table_name = df._create_hash_from_expression(select_expression)
322                cache_table = exp.to_table(cache_table_name)
323                original_alias_name = select_expression.args["cte_alias_name"]
324
325                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
326                    cache_table_name
327                )
328                sqlglot.schema.add_table(
329                    cache_table_name,
330                    {
331                        expression.alias_or_name: expression.type.sql(dialect=dialect)
332                        for expression in select_expression.expressions
333                    },
334                    dialect=dialect,
335                )
336
337                cache_storage_level = select_expression.args["cache_storage_level"]
338                options = [
339                    exp.Literal.string("storageLevel"),
340                    exp.Literal.string(cache_storage_level),
341                ]
342                expression = exp.Cache(
343                    this=cache_table, expression=select_expression, lazy=True, options=options
344                )
345
346                # We will drop the "view" if it exists before running the cache table
347                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
348            elif expression_type == exp.Create:
349                expression = df.output_expression_container.copy()
350                expression.set("expression", select_expression)
351            elif expression_type == exp.Insert:
352                expression = df.output_expression_container.copy()
353                select_without_ctes = select_expression.copy()
354                select_without_ctes.set("with", None)
355                expression.set("expression", select_without_ctes)
356
357                if select_expression.ctes:
358                    expression.set("with", exp.With(expressions=select_expression.ctes))
359            elif expression_type == exp.Select:
360                expression = select_expression
361            else:
362                raise ValueError(f"Invalid expression type: {expression_type}")
363
364            output_expressions.append(expression)
365
366        return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]
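
For orientation, here is a minimal usage sketch of sql(); the employee table, its schema, and the target dialect are made-up assumptions for illustration:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Register an assumed table schema so columns and types can be resolved.
sqlglot.schema.add_table("employee", {"id": "INT", "age": "INT"}, dialect="spark")

df = (
    SparkSession()
    .table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("id")).alias("num_employees"))
)

# sql() returns a list of statements: a plain select yields one statement,
# while cached DataFrames also emit the DROP VIEW / CACHE TABLE pair described above.
for statement in df.sql(dialect="bigquery"):
    print(statement)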
def copy(self, **kwargs) -> DataFrame:
368    def copy(self, **kwargs) -> DataFrame:
369        return DataFrame(**object_to_dict(self, **kwargs))
@operation(Operation.SELECT)
def select(self, *cols, **kwargs) -> DataFrame:
371    @operation(Operation.SELECT)
372    def select(self, *cols, **kwargs) -> DataFrame:
373        cols = self._ensure_and_normalize_cols(cols)
374        kwargs["append"] = kwargs.get("append", False)
375        if self.expression.args.get("joins"):
376            ambiguous_cols = [
377                col
378                for col in cols
379                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
380            ]
381            if ambiguous_cols:
382                join_table_identifiers = [
383                    x.this for x in get_tables_from_expression_with_join(self.expression)
384                ]
385                cte_names_in_join = [x.this for x in join_table_identifiers]
386                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
387                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
388                # of Spark.
389                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
390                for ambiguous_col in ambiguous_cols:
391                    ctes_with_column = [
392                        cte
393                        for cte in self.expression.ctes
394                        if cte.alias_or_name in cte_names_in_join
395                        and ambiguous_col.alias_or_name in cte.this.named_selects
396                    ]
397                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
398                    # use the same CTE we used before
399                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
400                    if cte:
401                        resolved_column_position[ambiguous_col] += 1
402                    else:
403                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
404                    ambiguous_col.expression.set("table", cte.alias_or_name)
405        return self.copy(
406            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
407        )
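
As a small illustrative sketch (the data and column names are assumptions), strings and Column expressions can be mixed freely in a projection:

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 25), (2, 16)], ["id", "age"])

# Strings are normalized to Columns; alias() names the output expression.
projected = df.select("id", (F.col("age") + 1).alias("age_next"))
print(projected.sql()[0])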
@operation(Operation.NO_OP)
def alias(self, name: str, **kwargs) -> DataFrame:
409    @operation(Operation.NO_OP)
410    def alias(self, name: str, **kwargs) -> DataFrame:
411        new_sequence_id = self.spark._random_sequence_id
412        df = self.copy()
413        for join_hint in df.pending_join_hints:
414            for expression in join_hint.expressions:
415                if expression.alias_or_name == self.sequence_id:
416                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
417        df.spark._add_alias_to_mapping(name, new_sequence_id)
418        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
@operation(Operation.WHERE)
def where( self, column: Union[Column, bool], **kwargs) -> DataFrame:
420    @operation(Operation.WHERE)
421    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
422        col = self._ensure_and_normalize_col(column)
423        return self.copy(expression=self.expression.where(col.expression))
@operation(Operation.WHERE)
def filter( self, column: Union[Column, bool], **kwargs) -> DataFrame:
420    @operation(Operation.WHERE)
421    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
422        col = self._ensure_and_normalize_col(column)
423        return self.copy(expression=self.expression.where(col.expression))
@operation(Operation.GROUP_BY)
def groupBy(self, *cols, **kwargs) -> GroupedData:
427    @operation(Operation.GROUP_BY)
428    def groupBy(self, *cols, **kwargs) -> GroupedData:
429        columns = self._ensure_and_normalize_cols(cols)
430        return GroupedData(self, columns, self.last_op)
@operation(Operation.SELECT)
def agg(self, *exprs, **kwargs) -> DataFrame:
432    @operation(Operation.SELECT)
433    def agg(self, *exprs, **kwargs) -> DataFrame:
434        cols = self._ensure_and_normalize_cols(exprs)
435        return self.groupBy().agg(*cols)
@operation(Operation.FROM)
def join( self, other_df: DataFrame, on: Union[str, List[str], Column, List[Column]], how: str = 'inner', **kwargs) -> DataFrame:
437    @operation(Operation.FROM)
438    def join(
439        self,
440        other_df: DataFrame,
441        on: t.Union[str, t.List[str], Column, t.List[Column]],
442        how: str = "inner",
443        **kwargs,
444    ) -> DataFrame:
445        other_df = other_df._convert_leaf_to_cte()
446        join_columns = self._ensure_list_of_columns(on)
447        # We will determine actual "join on" expression later so we don't provide it at first
448        join_expression = self.expression.join(
449            other_df.latest_cte_name, join_type=how.replace("_", " ")
450        )
451        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
452        self_columns = self._get_outer_select_columns(join_expression)
453        other_columns = self._get_outer_select_columns(other_df)
454        # Determines the join clause and select columns to be used based on what type of columns were provided for
455        # the join. The columns returned changes based on how the on expression is provided.
456        if isinstance(join_columns[0].expression, exp.Column):
457            """
458            Unique characteristics of join on column names only:
459            * The column names are put at the front of the select list
460            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
461            """
462            table_names = [
463                table.alias_or_name
464                for table in get_tables_from_expression_with_join(join_expression)
465            ]
466            potential_ctes = [
467                cte
468                for cte in join_expression.ctes
469                if cte.alias_or_name in table_names
470                and cte.alias_or_name != other_df.latest_cte_name
471            ]
472            # Determine the table to reference for the left side of the join by checking each of the left side
473            # tables and see if they have the column being referenced.
474            join_column_pairs = []
475            for join_column in join_columns:
476                num_matching_ctes = 0
477                for cte in potential_ctes:
478                    if join_column.alias_or_name in cte.this.named_selects:
479                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
480                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
481                        join_column_pairs.append((left_column, right_column))
482                        num_matching_ctes += 1
483                if num_matching_ctes > 1:
484                    raise ValueError(
485                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
486                    )
487                elif num_matching_ctes == 0:
488                    raise ValueError(
489                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
490                    )
491            join_clause = functools.reduce(
492                lambda x, y: x & y,
493                [left_column == right_column for left_column, right_column in join_column_pairs],
494            )
495            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
496            # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
497            select_column_names = [
498                (
499                    column.alias_or_name
500                    if not isinstance(column.expression.this, exp.Star)
501                    else column.sql()
502                )
503                for column in self_columns + other_columns
504            ]
505            select_column_names = [
506                column_name
507                for column_name in select_column_names
508                if column_name not in join_column_names
509            ]
510            select_column_names = join_column_names + select_column_names
511        else:
512            """
513            Unique characteristics of join on expressions:
514            * There is no deduplication of the results.
515            * The left join dataframe columns go first and right come after. No sort preference is given to join columns
516            """
517            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
518            if len(join_columns) > 1:
519                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
520            join_clause = join_columns[0]
521            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
522
523        # Update the on expression with the actual join clause to replace the dummy one from before
524        join_expression.args["joins"][-1].set("on", join_clause.expression)
525        new_df = self.copy(expression=join_expression)
526        new_df.pending_join_hints.extend(self.pending_join_hints)
527        new_df.pending_hints.extend(other_df.pending_hints)
528        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
529        return new_df
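
A minimal sketch of a join on column names, the first branch described above (the orders/users tables and their schemas are assumptions for illustration):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("orders", {"user_id": "INT", "total": "DOUBLE"}, dialect="spark")
sqlglot.schema.add_table("users", {"user_id": "INT", "name": "STRING"}, dialect="spark")

spark = SparkSession()
orders = spark.table("orders")
users = spark.table("users")

# Joining on a column name deduplicates user_id and puts it first in the select list.
joined = orders.join(users, on="user_id", how="inner")
print(joined.sql()[0])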
@operation(Operation.ORDER_BY)
def orderBy( self, *cols: Union[str, Column], ascending: Union[Any, List[Any], NoneType] = None) -> DataFrame:
531    @operation(Operation.ORDER_BY)
532    def orderBy(
533        self,
534        *cols: t.Union[str, Column],
535        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
536    ) -> DataFrame:
537        """
538        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
539        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
540        is unlikely to come up.
541        """
542        columns = self._ensure_and_normalize_cols(cols)
543        pre_ordered_col_indexes = [
544            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
545        ]
546        if ascending is None:
547            ascending = [True] * len(columns)
548        elif not isinstance(ascending, list):
549            ascending = [ascending] * len(columns)
550        ascending = [bool(x) for x in ascending]
551        assert len(columns) == len(
552            ascending
553        ), "The length of items in ascending must equal the number of columns provided"
554        col_and_ascending = list(zip(columns, ascending))
555        order_by_columns = [
556            (
557                exp.Ordered(this=col.expression, desc=not asc)
558                if i not in pre_ordered_col_indexes
559                else columns[i].column_expression
560            )
561            for i, (col, asc) in enumerate(col_and_ascending)
562        ]
563        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any pre-ordered columns take priority over whatever is provided in ascending. Spark's behavior here is irregular and can result in runtime errors. Users shouldn't mix the two anyway, so this is unlikely to come up.
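
For example (a sketch; the DataFrame and its columns are assumptions):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 25), (2, 16)], ["id", "age"])

# Directions supplied through the ascending list, mirroring PySpark.
by_flags = df.orderBy("age", "id", ascending=[False, True])

# Or make the direction explicit per column and skip ascending entirely.
by_cols = df.orderBy(F.col("age").desc(), F.col("id").asc())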

@operation(Operation.ORDER_BY)
def sort( self, *cols: Union[str, Column], ascending: Union[Any, List[Any], NoneType] = None) -> DataFrame:
531    @operation(Operation.ORDER_BY)
532    def orderBy(
533        self,
534        *cols: t.Union[str, Column],
535        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
536    ) -> DataFrame:
537        """
538        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
539        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
540        is unlikely to come up.
541        """
542        columns = self._ensure_and_normalize_cols(cols)
543        pre_ordered_col_indexes = [
544            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
545        ]
546        if ascending is None:
547            ascending = [True] * len(columns)
548        elif not isinstance(ascending, list):
549            ascending = [ascending] * len(columns)
550        ascending = [bool(x) for x in ascending]
551        assert len(columns) == len(
552            ascending
553        ), "The length of items in ascending must equal the number of columns provided"
554        col_and_ascending = list(zip(columns, ascending))
555        order_by_columns = [
556            (
557                exp.Ordered(this=col.expression, desc=not asc)
558                if i not in pre_ordered_col_indexes
559                else columns[i].column_expression
560            )
561            for i, (col, asc) in enumerate(col_and_ascending)
562        ]
563        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any pre-ordered columns take priority over whatever is provided in ascending. Spark's behavior here is irregular and can result in runtime errors. Users shouldn't mix the two anyway, so this is unlikely to come up.

@operation(Operation.FROM)
def union( self, other: DataFrame) -> DataFrame:
567    @operation(Operation.FROM)
568    def union(self, other: DataFrame) -> DataFrame:
569        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionAll( self, other: DataFrame) -> DataFrame:
567    @operation(Operation.FROM)
568    def union(self, other: DataFrame) -> DataFrame:
569        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionByName( self, other: DataFrame, allowMissingColumns: bool = False):
573    @operation(Operation.FROM)
574    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
575        l_columns = self.columns
576        r_columns = other.columns
577        if not allowMissingColumns:
578            l_expressions = l_columns
579            r_expressions = l_columns
580        else:
581            l_expressions = []
582            r_expressions = []
583            r_columns_unused = copy(r_columns)
584            for l_column in l_columns:
585                l_expressions.append(l_column)
586                if l_column in r_columns:
587                    r_expressions.append(l_column)
588                    r_columns_unused.remove(l_column)
589                else:
590                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
591            for r_column in r_columns_unused:
592                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
593                r_expressions.append(r_column)
594        r_df = (
595            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
596        )
597        l_df = self.copy()
598        if allowMissingColumns:
599            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
600        return l_df._set_operation(exp.Union, r_df, False)
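
A sketch of both modes (data and column names are assumptions):

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
left = spark.createDataFrame([(1, "a")], ["id", "name"])
right = spark.createDataFrame([("b", 2, True)], ["name", "id", "extra"])

# Columns are matched by name, not by position.
same_shape = left.unionByName(right.drop("extra"))

# With allowMissingColumns=True, the missing "extra" column is NULL-filled on the left.
widened = left.unionByName(right, allowMissingColumns=True)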
@operation(Operation.FROM)
def intersect( self, other: DataFrame) -> DataFrame:
602    @operation(Operation.FROM)
603    def intersect(self, other: DataFrame) -> DataFrame:
604        return self._set_operation(exp.Intersect, other, True)
@operation(Operation.FROM)
def intersectAll( self, other: DataFrame) -> DataFrame:
606    @operation(Operation.FROM)
607    def intersectAll(self, other: DataFrame) -> DataFrame:
608        return self._set_operation(exp.Intersect, other, False)
@operation(Operation.FROM)
def exceptAll( self, other: DataFrame) -> DataFrame:
610    @operation(Operation.FROM)
611    def exceptAll(self, other: DataFrame) -> DataFrame:
612        return self._set_operation(exp.Except, other, False)
@operation(Operation.SELECT)
def distinct(self) -> DataFrame:
614    @operation(Operation.SELECT)
615    def distinct(self) -> DataFrame:
616        return self.copy(expression=self.expression.distinct())
@operation(Operation.SELECT)
def dropDuplicates(self, subset: Optional[List[str]] = None):
618    @operation(Operation.SELECT)
619    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
620        if not subset:
621            return self.distinct()
622        column_names = ensure_list(subset)
623        window = Window.partitionBy(*column_names).orderBy(*column_names)
624        return (
625            self.copy()
626            .withColumn("row_num", F.row_number().over(window))
627            .where(F.col("row_num") == F.lit(1))
628            .drop("row_num")
629        )
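
As the source shows, a subset is handled with a ROW_NUMBER() window rather than DISTINCT. A sketch (the data is an assumption):

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, "a"), (1, "b"), (2, "c")], ["id", "name"])

distinct_rows = df.dropDuplicates()       # plain SELECT DISTINCT
first_per_id = df.dropDuplicates(["id"])  # ROW_NUMBER() partitioned by id, keeping row_num = 1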
@operation(Operation.FROM)
def dropna( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
631    @operation(Operation.FROM)
632    def dropna(
633        self,
634        how: str = "any",
635        thresh: t.Optional[int] = None,
636        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
637    ) -> DataFrame:
638        minimum_non_null = thresh or 0  # will be determined later if thresh is null
639        new_df = self.copy()
640        all_columns = self._get_outer_select_columns(new_df.expression)
641        if subset:
642            null_check_columns = self._ensure_and_normalize_cols(subset)
643        else:
644            null_check_columns = all_columns
645        if thresh is None:
646            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
647        else:
648            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
649        if minimum_num_nulls > len(null_check_columns):
650            raise RuntimeError(
651                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
652                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
653            )
654        if_null_checks = [
655            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
656        ]
657        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
658        num_nulls = nulls_added_together.alias("num_nulls")
659        new_df = new_df.select(num_nulls, append=True)
660        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
661        final_df = filtered_df.select(*all_columns)
662        return final_df
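
A sketch of the two main modes (data and column names are assumptions):

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, None), (None, None), (3, "c")], ["id", "name"])

any_null = df.dropna(how="any")                      # drop rows containing any NULL
enough = df.dropna(thresh=1, subset=["id", "name"])  # keep rows with at least 1 non-NULL value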
@operation(Operation.FROM)
def fillna( self, value: ColumnLiterals, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
664    @operation(Operation.FROM)
665    def fillna(
666        self,
667        value: t.Union[ColumnLiterals],
668        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
669    ) -> DataFrame:
670        """
671        Functionality Difference: If you provide a value to replace a null and that type conflicts
672        with the type of the column then PySpark will just ignore your replacement.
673        This will try to cast them to be the same in some cases. So they won't always match.
674        Best to not mix types so make sure replacement is the same type as the column
675
676        Possibility for improvement: Use `typeof` function to get the type of the column
677        and check if it matches the type of the value provided. If not then make it null.
678        """
679        from sqlglot.dataframe.sql.functions import lit
680
681        values = None
682        columns = None
683        new_df = self.copy()
684        all_columns = self._get_outer_select_columns(new_df.expression)
685        all_column_mapping = {column.alias_or_name: column for column in all_columns}
686        if isinstance(value, dict):
687            values = list(value.values())
688            columns = self._ensure_and_normalize_cols(list(value))
689        if not columns:
690            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
691        if not values:
692            values = [value] * len(columns)
693        value_columns = [lit(value) for value in values]
694
695        null_replacement_mapping = {
696            column.alias_or_name: (
697                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
698            )
699            for column, value in zip(columns, value_columns)
700        }
701        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
702        null_replacement_columns = [
703            null_replacement_mapping[column.alias_or_name] for column in all_columns
704        ]
705        new_df = new_df.select(*null_replacement_columns)
706        return new_df

Functionality difference: if you provide a replacement value whose type conflicts with the column's type, PySpark simply ignores the replacement, whereas this implementation will in some cases try to cast the two to a common type, so the results won't always match. It's best not to mix types: make sure the replacement is the same type as the column.

Possible improvement: use the typeof function to get the column's type and check whether it matches the type of the provided value; if it doesn't, replace with null.
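
A sketch of both call forms (data and column names are assumptions):

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, None), (None, "b")], ["id", "name"])

# Single value limited to a subset; keep the replacement type aligned with the column.
ids_filled = df.fillna(0, subset=["id"])

# Dict form: per-column replacements keyed by column name (subset is ignored here).
all_filled = df.fillna({"id": 0, "name": "unknown"})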

@operation(Operation.FROM)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[Collection[ColumnOrName], ColumnOrName, NoneType] = None) -> DataFrame:
708    @operation(Operation.FROM)
709    def replace(
710        self,
711        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
712        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
713        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
714    ) -> DataFrame:
715        from sqlglot.dataframe.sql.functions import lit
716
717        old_values = None
718        new_df = self.copy()
719        all_columns = self._get_outer_select_columns(new_df.expression)
720        all_column_mapping = {column.alias_or_name: column for column in all_columns}
721
722        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
723        if isinstance(to_replace, dict):
724            old_values = list(to_replace)
725            new_values = list(to_replace.values())
726        elif not old_values and isinstance(to_replace, list):
727            assert isinstance(value, list), "value must be a list since the replacements are a list"
728            assert len(to_replace) == len(
729                value
730            ), "the replacements and values must be the same length"
731            old_values = to_replace
732            new_values = value
733        else:
734            old_values = [to_replace] * len(columns)
735            new_values = [value] * len(columns)
736        old_values = [lit(value) for value in old_values]
737        new_values = [lit(value) for value in new_values]
738
739        replacement_mapping = {}
740        for column in columns:
741            expression = Column(None)
742            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
743                if i == 0:
744                    expression = F.when(column == old_value, new_value)
745                else:
746                    expression = expression.when(column == old_value, new_value)  # type: ignore
747            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
748                column.expression.alias_or_name
749            )
750
751        replacement_mapping = {**all_column_mapping, **replacement_mapping}
752        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
753        new_df = new_df.select(*replacement_columns)
754        return new_df
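
A sketch of the scalar and dict forms (data and column names are assumptions):

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])

ints_replaced = df.replace(1, 100, subset=["id"])                  # scalar old -> new
strs_remapped = df.replace({"a": "x", "b": "y"}, subset=["name"])  # dict of old -> new pairs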
@operation(Operation.SELECT)
def withColumn( self, colName: str, col: Column) -> DataFrame:
756    @operation(Operation.SELECT)
757    def withColumn(self, colName: str, col: Column) -> DataFrame:
758        col = self._ensure_and_normalize_col(col)
759        existing_col_names = self.expression.named_selects
760        existing_col_index = (
761            existing_col_names.index(colName) if colName in existing_col_names else None
762        )
763        if existing_col_index is not None:
764            expression = self.expression.copy()
765            expression.expressions[existing_col_index] = col.expression
766            return self.copy(expression=expression)
767        return self.copy().select(col.alias(colName), append=True)
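
For example (a sketch; data and names are assumptions):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 25)], ["id", "age"])

# Appends a new column, or replaces the expression in place if the name already exists.
df2 = df.withColumn("age_next", F.col("age") + 1)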
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
769    @operation(Operation.SELECT)
770    def withColumnRenamed(self, existing: str, new: str):
771        expression = self.expression.copy()
772        existing_columns = [
773            expression
774            for expression in expression.expressions
775            if expression.alias_or_name == existing
776        ]
777        if not existing_columns:
778            raise ValueError("Tried to rename a column that doesn't exist")
779        for existing_column in existing_columns:
780            if isinstance(existing_column, exp.Column):
781                existing_column.replace(exp.alias_(existing_column, new))
782            else:
783                existing_column.set("alias", exp.to_identifier(new))
784        return self.copy(expression=expression)
@operation(Operation.SELECT)
def drop( self, *cols: Union[str, Column]) -> DataFrame:
786    @operation(Operation.SELECT)
787    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
788        all_columns = self._get_outer_select_columns(self.expression)
789        drop_cols = self._ensure_and_normalize_cols(cols)
790        new_columns = [
791            col
792            for col in all_columns
793            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
794        ]
795        return self.copy().select(*new_columns, append=False)
@operation(Operation.LIMIT)
def limit(self, num: int) -> DataFrame:
797    @operation(Operation.LIMIT)
798    def limit(self, num: int) -> DataFrame:
799        return self.copy(expression=self.expression.limit(num))
@operation(Operation.NO_OP)
def hint( self, name: str, *parameters: Union[str, int, NoneType]) -> DataFrame:
801    @operation(Operation.NO_OP)
802    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
803        parameter_list = ensure_list(parameters)
804        parameter_columns = (
805            self._ensure_list_of_columns(parameter_list)
806            if parameters
807            else Column.ensure_cols([self.sequence_id])
808        )
809        return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition( self, numPartitions: Union[int, ColumnOrName], *cols: ColumnOrName) -> DataFrame:
811    @operation(Operation.NO_OP)
812    def repartition(
813        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
814    ) -> DataFrame:
815        num_partition_cols = self._ensure_list_of_columns(numPartitions)
816        columns = self._ensure_and_normalize_cols(cols)
817        args = num_partition_cols + columns
818        return self._hint("repartition", args)
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> DataFrame:
820    @operation(Operation.NO_OP)
821    def coalesce(self, numPartitions: int) -> DataFrame:
822        num_partitions = Column.ensure_cols([numPartitions])
823        return self._hint("coalesce", num_partitions)
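
Both methods only attach hints that surface in the generated SQL; a sketch (data and names are assumptions):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 25)], ["id", "age"])

repartitioned = df.repartition(8, F.col("age"))  # adds a REPARTITION hint
narrowed = df.coalesce(1)                        # adds a COALESCE hint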
@operation(Operation.NO_OP)
def cache(self) -> DataFrame:
825    @operation(Operation.NO_OP)
826    def cache(self) -> DataFrame:
827        return self._cache(storage_level="MEMORY_AND_DISK")
@operation(Operation.NO_OP)
def persist( self, storageLevel: str = 'MEMORY_AND_DISK_SER') -> DataFrame:
829    @operation(Operation.NO_OP)
830    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
831        """
832        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
833        """
834        return self._cache(storageLevel)
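
A sketch of how caching surfaces in the generated SQL, based on the exp.Cache branch of sql() above (data and names are assumptions):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 25), (2, 16)], ["id", "age"]).persist("MEMORY_ONLY")

# sql() should now emit a DROP VIEW IF EXISTS, a lazy CACHE TABLE carrying the
# storageLevel option, and the final SELECT that reads from the cached table.
for statement in df.where(F.col("age") > 18).sql():
    print(statement)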
class GroupedData:
14class GroupedData:
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
20
21    def _get_function_applied_columns(
22        self, func_name: str, cols: t.Tuple[str, ...]
23    ) -> t.List[Column]:
24        func_name = func_name.lower()
25        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
26
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
40
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
43
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
46
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
49
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
52
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
55
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
58
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Pivot is not currently implemented")
GroupedData( df: DataFrame, group_by_cols: List[Column], last_op: sqlglot.dataframe.sql.operations.Operation)
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
spark
last_op
group_by_cols
@operation(Operation.SELECT)
def agg( self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame:
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
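
A sketch of the two argument forms (data and names are assumptions):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 25), (2, 25)], ["id", "age"])
grouped = df.groupBy(F.col("age"))

by_dict = grouped.agg({"id": "count"})                            # dict form: {column: aggregate name}
by_cols = grouped.agg(F.countDistinct(F.col("id")).alias("ids"))  # Column form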
def count(self) -> DataFrame:
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
def mean(self, *cols: str) -> DataFrame:
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
def avg(self, *cols: str) -> DataFrame:
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
def max(self, *cols: str) -> DataFrame:
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
def min(self, *cols: str) -> DataFrame:
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
def sum(self, *cols: str) -> DataFrame:
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
def pivot(self, *cols: str) -> DataFrame:
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Pivot is not currently implemented")
class Column:
 16class Column:
 17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
 18        from sqlglot.dataframe.sql.session import SparkSession
 19
 20        if isinstance(expression, Column):
 21            expression = expression.expression  # type: ignore
 22        elif expression is None or not isinstance(expression, (str, exp.Expression)):
 23            expression = self._lit(expression).expression  # type: ignore
 24        elif not isinstance(expression, exp.Column):
 25            expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
 26                SparkSession().dialect.normalize_identifier, copy=False
 27            )
 28        if expression is None:
 29            raise ValueError(f"Could not parse {expression}")
 30
 31        self.expression: exp.Expression = expression  # type: ignore
 32
 33    def __repr__(self):
 34        return repr(self.expression)
 35
 36    def __hash__(self):
 37        return hash(self.expression)
 38
 39    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 40        return self.binary_op(exp.EQ, other)
 41
 42    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 43        return self.binary_op(exp.NEQ, other)
 44
 45    def __gt__(self, other: ColumnOrLiteral) -> Column:
 46        return self.binary_op(exp.GT, other)
 47
 48    def __ge__(self, other: ColumnOrLiteral) -> Column:
 49        return self.binary_op(exp.GTE, other)
 50
 51    def __lt__(self, other: ColumnOrLiteral) -> Column:
 52        return self.binary_op(exp.LT, other)
 53
 54    def __le__(self, other: ColumnOrLiteral) -> Column:
 55        return self.binary_op(exp.LTE, other)
 56
 57    def __and__(self, other: ColumnOrLiteral) -> Column:
 58        return self.binary_op(exp.And, other)
 59
 60    def __or__(self, other: ColumnOrLiteral) -> Column:
 61        return self.binary_op(exp.Or, other)
 62
 63    def __mod__(self, other: ColumnOrLiteral) -> Column:
 64        return self.binary_op(exp.Mod, other)
 65
 66    def __add__(self, other: ColumnOrLiteral) -> Column:
 67        return self.binary_op(exp.Add, other)
 68
 69    def __sub__(self, other: ColumnOrLiteral) -> Column:
 70        return self.binary_op(exp.Sub, other)
 71
 72    def __mul__(self, other: ColumnOrLiteral) -> Column:
 73        return self.binary_op(exp.Mul, other)
 74
 75    def __truediv__(self, other: ColumnOrLiteral) -> Column:
 76        return self.binary_op(exp.Div, other)
 77
 78    def __div__(self, other: ColumnOrLiteral) -> Column:
 79        return self.binary_op(exp.Div, other)
 80
 81    def __neg__(self) -> Column:
 82        return self.unary_op(exp.Neg)
 83
 84    def __radd__(self, other: ColumnOrLiteral) -> Column:
 85        return self.inverse_binary_op(exp.Add, other)
 86
 87    def __rsub__(self, other: ColumnOrLiteral) -> Column:
 88        return self.inverse_binary_op(exp.Sub, other)
 89
 90    def __rmul__(self, other: ColumnOrLiteral) -> Column:
 91        return self.inverse_binary_op(exp.Mul, other)
 92
 93    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
 94        return self.inverse_binary_op(exp.Div, other)
 95
 96    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
 97        return self.inverse_binary_op(exp.Div, other)
 98
 99    def __rmod__(self, other: ColumnOrLiteral) -> Column:
100        return self.inverse_binary_op(exp.Mod, other)
101
102    def __pow__(self, power: ColumnOrLiteral, modulo=None):
103        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))
104
105    def __rpow__(self, power: ColumnOrLiteral):
106        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))
107
108    def __invert__(self):
109        return self.unary_op(exp.Not)
110
111    def __rand__(self, other: ColumnOrLiteral) -> Column:
112        return self.inverse_binary_op(exp.And, other)
113
114    def __ror__(self, other: ColumnOrLiteral) -> Column:
115        return self.inverse_binary_op(exp.Or, other)
116
117    @classmethod
118    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
119        return cls(value)
120
121    @classmethod
122    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
123        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
124
125    @classmethod
126    def _lit(cls, value: ColumnOrLiteral) -> Column:
127        if isinstance(value, dict):
128            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
129            return cls(exp.Struct(expressions=columns))
130        return cls(exp.convert(value))
131
132    @classmethod
133    def invoke_anonymous_function(
134        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
135    ) -> Column:
136        columns = [] if column is None else [cls.ensure_col(column)]
137        column_args = [cls.ensure_col(arg) for arg in args]
138        expressions = [x.expression for x in columns + column_args]
139        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
140        return Column(new_expression)
141
142    @classmethod
143    def invoke_expression_over_column(
144        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
145    ) -> Column:
146        ensured_column = None if column is None else cls.ensure_col(column)
147        ensure_expression_values = {
148            k: (
149                [Column.ensure_col(x).expression for x in v]
150                if is_iterable(v)
151                else Column.ensure_col(v).expression
152            )
153            for k, v in kwargs.items()
154            if v is not None
155        }
156        new_expression = (
157            callable_expression(**ensure_expression_values)
158            if ensured_column is None
159            else callable_expression(
160                this=ensured_column.column_expression, **ensure_expression_values
161            )
162        )
163        return Column(new_expression)
164
165    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
166        return Column(
167            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
168        )
169
170    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
171        return Column(
172            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
173        )
174
175    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
176        return Column(klass(this=self.column_expression, **kwargs))
177
178    @property
179    def is_alias(self):
180        return isinstance(self.expression, exp.Alias)
181
182    @property
183    def is_column(self):
184        return isinstance(self.expression, exp.Column)
185
186    @property
187    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
188        return self.expression.unalias()
189
190    @property
191    def alias_or_name(self) -> str:
192        return self.expression.alias_or_name
193
194    @classmethod
195    def ensure_literal(cls, value) -> Column:
196        from sqlglot.dataframe.sql.functions import lit
197
198        if isinstance(value, cls):
199            value = value.expression
200        if not isinstance(value, exp.Literal):
201            return lit(value)
202        return Column(value)
203
204    def copy(self) -> Column:
205        return Column(self.expression.copy())
206
207    def set_table_name(self, table_name: str, copy=False) -> Column:
208        expression = self.expression.copy() if copy else self.expression
209        expression.set("table", exp.to_identifier(table_name))
210        return Column(expression)
211
212    def sql(self, **kwargs) -> str:
213        from sqlglot.dataframe.sql.session import SparkSession
214
215        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
216
217    def alias(self, name: str) -> Column:
218        from sqlglot.dataframe.sql.session import SparkSession
219
220        dialect = SparkSession().dialect
221        alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
222        new_expression = exp.alias_(
223            self.column_expression,
224            alias.this if isinstance(alias, exp.Column) else name,
225            dialect=dialect,
226        )
227        return Column(new_expression)
228
229    def asc(self) -> Column:
230        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
231        return Column(new_expression)
232
233    def desc(self) -> Column:
234        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
235        return Column(new_expression)
236
237    asc_nulls_first = asc
238
239    def asc_nulls_last(self) -> Column:
240        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
241        return Column(new_expression)
242
243    def desc_nulls_first(self) -> Column:
244        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
245        return Column(new_expression)
246
247    desc_nulls_last = desc
248
249    def when(self, condition: Column, value: t.Any) -> Column:
250        from sqlglot.dataframe.sql.functions import when
251
252        column_with_if = when(condition, value)
253        if not isinstance(self.expression, exp.Case):
254            return column_with_if
255        new_column = self.copy()
256        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
257        return new_column
258
259    def otherwise(self, value: t.Any) -> Column:
260        from sqlglot.dataframe.sql.functions import lit
261
262        true_value = value if isinstance(value, Column) else lit(value)
263        new_column = self.copy()
264        new_column.expression.set("default", true_value.column_expression)
265        return new_column
266
267    def isNull(self) -> Column:
268        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
269        return Column(new_expression)
270
271    def isNotNull(self) -> Column:
272        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
273        return Column(new_expression)
274
275    def cast(self, dataType: t.Union[str, DataType]) -> Column:
276        """
277        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
278        Sqlglot doesn't currently replicate this class so it only accepts a string
279        """
280        from sqlglot.dataframe.sql.session import SparkSession
281
282        if isinstance(dataType, DataType):
283            dataType = dataType.simpleString()
284        return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
285
286    def startswith(self, value: t.Union[str, Column]) -> Column:
287        value = self._lit(value) if not isinstance(value, Column) else value
288        return self.invoke_anonymous_function(self, "STARTSWITH", value)
289
290    def endswith(self, value: t.Union[str, Column]) -> Column:
291        value = self._lit(value) if not isinstance(value, Column) else value
292        return self.invoke_anonymous_function(self, "ENDSWITH", value)
293
294    def rlike(self, regexp: str) -> Column:
295        return self.invoke_expression_over_column(
296            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
297        )
298
299    def like(self, other: str):
300        return self.invoke_expression_over_column(
301            self, exp.Like, expression=self._lit(other).expression
302        )
303
304    def ilike(self, other: str):
305        return self.invoke_expression_over_column(
306            self, exp.ILike, expression=self._lit(other).expression
307        )
308
309    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
310        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
311        length = self._lit(length) if not isinstance(length, Column) else length
312        return Column.invoke_expression_over_column(
313            self, exp.Substring, start=startPos.expression, length=length.expression
314        )
315
316    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
317        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
318        expressions = [self._lit(x).expression for x in columns]
319        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
320
321    def between(
322        self,
323        lowerBound: t.Union[ColumnOrLiteral],
324        upperBound: t.Union[ColumnOrLiteral],
325    ) -> Column:
326        lower_bound_exp = (
327            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
328        )
329        upper_bound_exp = (
330            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
331        )
332        return Column(
333            exp.Between(
334                this=self.column_expression,
335                low=lower_bound_exp.expression,
336                high=upper_bound_exp.expression,
337            )
338        )
339
340    def over(self, window: WindowSpec) -> Column:
341        window_expression = window.expression.copy()
342        window_expression.set("this", self.column_expression)
343        return Column(window_expression)
Column( expression: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType])
17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
18        from sqlglot.dataframe.sql.session import SparkSession
19
20        if isinstance(expression, Column):
21            expression = expression.expression  # type: ignore
22        elif expression is None or not isinstance(expression, (str, exp.Expression)):
23            expression = self._lit(expression).expression  # type: ignore
24        elif not isinstance(expression, exp.Column):
25            expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
26                SparkSession().dialect.normalize_identifier, copy=False
27            )
28        if expression is None:
29            raise ValueError(f"Could not parse {expression}")
30
31        self.expression: exp.Expression = expression  # type: ignore
@classmethod
def ensure_col( cls, value: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType]) -> Column:
117    @classmethod
118    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
119        return cls(value)
@classmethod
def ensure_cols( cls, args: List[Union[ColumnOrLiteral, sqlglot.expressions.Expression]]) -> List[Column]:
121    @classmethod
122    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
123        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
@classmethod
def invoke_anonymous_function( cls, column: Optional[ColumnOrLiteral], func_name: str, *args: Optional[ColumnOrLiteral]) -> Column:
132    @classmethod
133    def invoke_anonymous_function(
134        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
135    ) -> Column:
136        columns = [] if column is None else [cls.ensure_col(column)]
137        column_args = [cls.ensure_col(arg) for arg in args]
138        expressions = [x.expression for x in columns + column_args]
139        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
140        return Column(new_expression)
@classmethod
def invoke_expression_over_column( cls, column: Optional[ColumnOrLiteral], callable_expression: Callable, **kwargs) -> Column:
142    @classmethod
143    def invoke_expression_over_column(
144        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
145    ) -> Column:
146        ensured_column = None if column is None else cls.ensure_col(column)
147        ensure_expression_values = {
148            k: (
149                [Column.ensure_col(x).expression for x in v]
150                if is_iterable(v)
151                else Column.ensure_col(v).expression
152            )
153            for k, v in kwargs.items()
154            if v is not None
155        }
156        new_expression = (
157            callable_expression(**ensure_expression_values)
158            if ensured_column is None
159            else callable_expression(
160                this=ensured_column.column_expression, **ensure_expression_values
161            )
162        )
163        return Column(new_expression)
def binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> Column:
165    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
166        return Column(
167            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
168        )
def inverse_binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> Column:
170    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
171        return Column(
172            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
173        )
def unary_op(self, klass: Callable, **kwargs) -> Column:
175    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
176        return Column(klass(this=self.column_expression, **kwargs))
is_alias
178    @property
179    def is_alias(self):
180        return isinstance(self.expression, exp.Alias)
is_column
182    @property
183    def is_column(self):
184        return isinstance(self.expression, exp.Column)
column_expression: Union[sqlglot.expressions.Column, sqlglot.expressions.Literal]
186    @property
187    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
188        return self.expression.unalias()
alias_or_name: str
190    @property
191    def alias_or_name(self) -> str:
192        return self.expression.alias_or_name
@classmethod
def ensure_literal(cls, value) -> Column:
194    @classmethod
195    def ensure_literal(cls, value) -> Column:
196        from sqlglot.dataframe.sql.functions import lit
197
198        if isinstance(value, cls):
199            value = value.expression
200        if not isinstance(value, exp.Literal):
201            return lit(value)
202        return Column(value)
def copy(self) -> Column:
204    def copy(self) -> Column:
205        return Column(self.expression.copy())
def set_table_name(self, table_name: str, copy=False) -> Column:
207    def set_table_name(self, table_name: str, copy=False) -> Column:
208        expression = self.expression.copy() if copy else self.expression
209        expression.set("table", exp.to_identifier(table_name))
210        return Column(expression)
def sql(self, **kwargs) -> str:
212    def sql(self, **kwargs) -> str:
213        from sqlglot.dataframe.sql.session import SparkSession
214
215        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
def alias(self, name: str) -> Column:
217    def alias(self, name: str) -> Column:
218        from sqlglot.dataframe.sql.session import SparkSession
219
220        dialect = SparkSession().dialect
221        alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
222        new_expression = exp.alias_(
223            self.column_expression,
224            alias.this if isinstance(alias, exp.Column) else name,
225            dialect=dialect,
226        )
227        return Column(new_expression)
def asc(self) -> Column:
229    def asc(self) -> Column:
230        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
231        return Column(new_expression)
def desc(self) -> Column:
233    def desc(self) -> Column:
234        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
235        return Column(new_expression)
def asc_nulls_first(self) -> Column:
229    def asc(self) -> Column:
230        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
231        return Column(new_expression)
def asc_nulls_last(self) -> Column:
239    def asc_nulls_last(self) -> Column:
240        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
241        return Column(new_expression)
def desc_nulls_first(self) -> Column:
243    def desc_nulls_first(self) -> Column:
244        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
245        return Column(new_expression)
def desc_nulls_last(self) -> Column:
233    def desc(self) -> Column:
234        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
235        return Column(new_expression)
def when( self, condition: Column, value: Any) -> Column:
249    def when(self, condition: Column, value: t.Any) -> Column:
250        from sqlglot.dataframe.sql.functions import when
251
252        column_with_if = when(condition, value)
253        if not isinstance(self.expression, exp.Case):
254            return column_with_if
255        new_column = self.copy()
256        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
257        return new_column
def otherwise(self, value: Any) -> Column:
259    def otherwise(self, value: t.Any) -> Column:
260        from sqlglot.dataframe.sql.functions import lit
261
262        true_value = value if isinstance(value, Column) else lit(value)
263        new_column = self.copy()
264        new_column.expression.set("default", true_value.column_expression)
265        return new_column
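
Chained whens extend a single CASE expression, and otherwise() fills in its default. A sketch (data and names are assumptions):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 12), (2, 40)], ["id", "age"])

age_band = (
    F.when(F.col("age") < 18, F.lit("minor"))
    .when(F.col("age") < 65, F.lit("adult"))
    .otherwise(F.lit("senior"))
    .alias("age_band")
)
labeled = df.select(F.col("id"), age_band)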
def isNull(self) -> Column:
267    def isNull(self) -> Column:
268        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
269        return Column(new_expression)
def isNotNull(self) -> Column:
271    def isNotNull(self) -> Column:
272        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
273        return Column(new_expression)
def cast( self, dataType: Union[str, sqlglot.dataframe.sql.types.DataType]) -> Column:
275    def cast(self, dataType: t.Union[str, DataType]) -> Column:
276        """
277        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
278        Sqlglot doesn't currently replicate this class so it only accepts a string
279        """
280        from sqlglot.dataframe.sql.session import SparkSession
281
282        if isinstance(dataType, DataType):
283            dataType = dataType.simpleString()
284        return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))

Functionality difference: PySpark's cast accepts an instance of PySpark's DataType class; sqlglot doesn't currently replicate that class, so it accepts a string type name (or sqlglot's own DataType, which is converted via simpleString()).
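
For example (a sketch; data and names are assumptions):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 25)], ["id", "age"])

# A string type name is the reliable form; a DataType instance is converted
# through simpleString() first.
as_text = df.select(F.col("id").cast("string").alias("id_str"))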

def startswith( self, value: Union[str, Column]) -> Column:
286    def startswith(self, value: t.Union[str, Column]) -> Column:
287        value = self._lit(value) if not isinstance(value, Column) else value
288        return self.invoke_anonymous_function(self, "STARTSWITH", value)
def endswith( self, value: Union[str, Column]) -> Column:
290    def endswith(self, value: t.Union[str, Column]) -> Column:
291        value = self._lit(value) if not isinstance(value, Column) else value
292        return self.invoke_anonymous_function(self, "ENDSWITH", value)
def rlike(self, regexp: str) -> Column:
294    def rlike(self, regexp: str) -> Column:
295        return self.invoke_expression_over_column(
296            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
297        )
def like(self, other: str):
299    def like(self, other: str):
300        return self.invoke_expression_over_column(
301            self, exp.Like, expression=self._lit(other).expression
302        )
def ilike(self, other: str):
304    def ilike(self, other: str):
305        return self.invoke_expression_over_column(
306            self, exp.ILike, expression=self._lit(other).expression
307        )
def substr( self, startPos: Union[int, Column], length: Union[int, Column]) -> Column:
309    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
310        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
311        length = self._lit(length) if not isinstance(length, Column) else length
312        return Column.invoke_expression_over_column(
313            self, exp.Substring, start=startPos.expression, length=length.expression
314        )
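startswith and endswith fall back to anonymous STARTSWITH/ENDSWITH function calls, while rlike, like, ilike, and substr map onto dedicated sqlglot expression nodes. An illustrative sketch (data is made up):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([["Jack Shephard"]], ["fname"])

    sel = df.select(
        F.col("fname").startswith("Jack").alias("is_jack"),
        F.col("fname").rlike("^J.*d$").alias("matches"),
        F.col("fname").substr(1, 4).alias("prefix"),  # plain ints are wrapped via _lit
    )
    print(sel.sql(dialect="spark")[0])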
def isin( self, *cols: Union[ColumnOrLiteral, Iterable[ColumnOrLiteral]]):
316    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
317        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
318        expressions = [self._lit(x).expression for x in columns]
319        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
def between( self, lowerBound: ColumnOrLiteral, upperBound: ColumnOrLiteral) -> Column:
321    def between(
322        self,
323        lowerBound: t.Union[ColumnOrLiteral],
324        upperBound: t.Union[ColumnOrLiteral],
325    ) -> Column:
326        lower_bound_exp = (
327            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
328        )
329        upper_bound_exp = (
330            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
331        )
332        return Column(
333            exp.Between(
334                this=self.column_expression,
335                low=lower_bound_exp.expression,
336                high=upper_bound_exp.expression,
337            )
338        )
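isin flattens a single iterable argument and literal-wraps each element, while between builds an exp.Between with inclusive bounds. A sketch (assumes DataFrame.where behaves as in PySpark):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[21], [40], [70]], ["age"])

    filtered = df.where(F.col("age").isin([18, 21, 65]) | F.col("age").between(30, 50))
    print(filtered.sql(dialect="spark")[0])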
def over( self, window: WindowSpec) -> Column:
340    def over(self, window: WindowSpec) -> Column:
341        window_expression = window.expression.copy()
342        window_expression.set("this", self.column_expression)
343        return Column(window_expression)
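over copies the window expression and injects the column as its this, so the original column is left untouched. A sketch combining it with the Window builder defined below (F.sum is assumed to accept a column name, as in PySpark):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql.window import Window

    spark = SparkSession()
    df = spark.createDataFrame([["a", 10], ["a", 20], ["b", 5]], ["dept", "salary"])

    w = Window.partitionBy("dept").orderBy(F.col("salary").desc())
    totals = df.select(F.col("dept"), F.sum("salary").over(w).alias("running_total"))
    print(totals.sql(dialect="spark")[0])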
class DataFrameNaFunctions:
837class DataFrameNaFunctions:
838    def __init__(self, df: DataFrame):
839        self.df = df
840
841    def drop(
842        self,
843        how: str = "any",
844        thresh: t.Optional[int] = None,
845        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
846    ) -> DataFrame:
847        return self.df.dropna(how=how, thresh=thresh, subset=subset)
848
849    def fill(
850        self,
851        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
852        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
853    ) -> DataFrame:
854        return self.df.fillna(value=value, subset=subset)
855
856    def replace(
857        self,
858        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
859        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
860        subset: t.Optional[t.Union[str, t.List[str]]] = None,
861    ) -> DataFrame:
862        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
DataFrameNaFunctions(df: DataFrame)
838    def __init__(self, df: DataFrame):
839        self.df = df
df
def drop( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
841    def drop(
842        self,
843        how: str = "any",
844        thresh: t.Optional[int] = None,
845        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
846    ) -> DataFrame:
847        return self.df.dropna(how=how, thresh=thresh, subset=subset)
def fill( self, value: Union[int, bool, float, str, Dict[str, Any]], subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
849    def fill(
850        self,
851        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
852        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
853    ) -> DataFrame:
854        return self.df.fillna(value=value, subset=subset)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[str, List[str], NoneType] = None) -> DataFrame:
856    def replace(
857        self,
858        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
859        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
860        subset: t.Optional[t.Union[str, t.List[str]]] = None,
861    ) -> DataFrame:
862        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
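These are thin wrappers that delegate to the corresponding DataFrame methods; in PySpark fashion they are reached through the df.na accessor (assumed here to return a DataFrameNaFunctions, as in PySpark):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, None], [None, "x"]], ["a", "b"])

    cleaned = df.na.drop(how="any", subset=["a"])  # delegates to df.dropna
    filled = df.na.fill({"a": 0, "b": "missing"})  # delegates to df.fillna
    print(cleaned.sql(dialect="spark")[0])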
class Window:
15class Window:
16    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
17    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
18    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
19    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
20
21    unboundedPreceding: int = _JAVA_MIN_LONG
22
23    unboundedFollowing: int = _JAVA_MAX_LONG
24
25    currentRow: int = 0
26
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
30
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
34
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
38
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
unboundedPreceding: int = -9223372036854775808
unboundedFollowing: int = 9223372036854775807
currentRow: int = 0
@classmethod
def partitionBy( cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
@classmethod
def orderBy( cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
@classmethod
def rowsBetween(cls, start: int, end: int) -> WindowSpec:
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
@classmethod
def rangeBetween(cls, start: int, end: int) -> WindowSpec:
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
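The classmethods are conveniences that start from an empty WindowSpec, and the boundary constants are clamped to the JVM long range for PySpark parity. For example:

    from sqlglot.dataframe.sql.window import Window

    # Equivalent to WindowSpec().partitionBy("dept").orderBy("ts").rowsBetween(...)
    w = Window.partitionBy("dept").orderBy("ts").rowsBetween(
        Window.unboundedPreceding, Window.currentRow
    )
    print(w.sql())  # WindowSpec.sql renders with the session dialect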
class WindowSpec:
 44class WindowSpec:
 45    def __init__(self, expression: exp.Expression = exp.Window()):
 46        self.expression = expression
 47
 48    def copy(self):
 49        return WindowSpec(self.expression.copy())
 50
 51    def sql(self, **kwargs) -> str:
 52        from sqlglot.dataframe.sql.session import SparkSession
 53
 54        return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
 55
 56    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 57        from sqlglot.dataframe.sql.column import Column
 58
 59        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 60        expressions = [Column.ensure_col(x).expression for x in cols]
 61        window_spec = self.copy()
 62        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
 63        partition_by_expressions.extend(expressions)
 64        window_spec.expression.set("partition_by", partition_by_expressions)
 65        return window_spec
 66
 67    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 68        from sqlglot.dataframe.sql.column import Column
 69
 70        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 71        expressions = [Column.ensure_col(x).expression for x in cols]
 72        window_spec = self.copy()
 73        if window_spec.expression.args.get("order") is None:
 74            window_spec.expression.set("order", exp.Order(expressions=[]))
 75        order_by = window_spec.expression.args["order"].expressions
 76        order_by.extend(expressions)
 77        window_spec.expression.args["order"].set("expressions", order_by)
 78        return window_spec
 79
 80    def _calc_start_end(
 81        self, start: int, end: int
 82    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
 83        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
 84            "start_side": None,
 85            "end_side": None,
 86        }
 87        if start == Window.currentRow:
 88            kwargs["start"] = "CURRENT ROW"
 89        else:
 90            kwargs = {
 91                **kwargs,
 92                **{
 93                    "start_side": "PRECEDING",
 94                    "start": (
 95                        "UNBOUNDED"
 96                        if start <= Window.unboundedPreceding
 97                        else F.lit(start).expression
 98                    ),
 99                },
100            }
101        if end == Window.currentRow:
102            kwargs["end"] = "CURRENT ROW"
103        else:
104            kwargs = {
105                **kwargs,
106                **{
107                    "end_side": "FOLLOWING",
108                    "end": (
109                        "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression
110                    ),
111                },
112            }
113        return kwargs
114
115    def rowsBetween(self, start: int, end: int) -> WindowSpec:
116        window_spec = self.copy()
117        spec = self._calc_start_end(start, end)
118        spec["kind"] = "ROWS"
119        window_spec.expression.set(
120            "spec",
121            exp.WindowSpec(
122                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
123            ),
124        )
125        return window_spec
126
127    def rangeBetween(self, start: int, end: int) -> WindowSpec:
128        window_spec = self.copy()
129        spec = self._calc_start_end(start, end)
130        spec["kind"] = "RANGE"
131        window_spec.expression.set(
132            "spec",
133            exp.WindowSpec(
134                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
135            ),
136        )
137        return window_spec
WindowSpec(expression: sqlglot.expressions.Expression = Window())
45    def __init__(self, expression: exp.Expression = exp.Window()):
46        self.expression = expression
expression
def copy(self):
48    def copy(self):
49        return WindowSpec(self.expression.copy())
def sql(self, **kwargs) -> str:
51    def sql(self, **kwargs) -> str:
52        from sqlglot.dataframe.sql.session import SparkSession
53
54        return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
def partitionBy( self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
56    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
57        from sqlglot.dataframe.sql.column import Column
58
59        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
60        expressions = [Column.ensure_col(x).expression for x in cols]
61        window_spec = self.copy()
62        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
63        partition_by_expressions.extend(expressions)
64        window_spec.expression.set("partition_by", partition_by_expressions)
65        return window_spec
def orderBy( self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
67    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
68        from sqlglot.dataframe.sql.column import Column
69
70        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
71        expressions = [Column.ensure_col(x).expression for x in cols]
72        window_spec = self.copy()
73        if window_spec.expression.args.get("order") is None:
74            window_spec.expression.set("order", exp.Order(expressions=[]))
75        order_by = window_spec.expression.args["order"].expressions
76        order_by.extend(expressions)
77        window_spec.expression.args["order"].set("expressions", order_by)
78        return window_spec
def rowsBetween(self, start: int, end: int) -> WindowSpec:
115    def rowsBetween(self, start: int, end: int) -> WindowSpec:
116        window_spec = self.copy()
117        spec = self._calc_start_end(start, end)
118        spec["kind"] = "ROWS"
119        window_spec.expression.set(
120            "spec",
121            exp.WindowSpec(
122                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
123            ),
124        )
125        return window_spec
def rangeBetween(self, start: int, end: int) -> WindowSpec:
127    def rangeBetween(self, start: int, end: int) -> WindowSpec:
128        window_spec = self.copy()
129        spec = self._calc_start_end(start, end)
130        spec["kind"] = "RANGE"
131        window_spec.expression.set(
132            "spec",
133            exp.WindowSpec(
134                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
135            ),
136        )
137        return window_spec
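Each builder copies the spec before mutating it (_calc_start_end then translates integer bounds into the PRECEDING/FOLLOWING sides of the exp.WindowSpec frame), so a partial window can be shared and extended safely. A sketch of those copy semantics:

    from sqlglot.dataframe.sql.window import Window

    base = Window.partitionBy("dept").orderBy("ts")  # shared base spec
    rows_frame = base.rowsBetween(Window.unboundedPreceding, Window.currentRow)
    range_frame = base.rangeBetween(Window.currentRow, Window.unboundedFollowing)
    # Every builder call copied the spec first, so base still has no frame clause
    print(base.sql(), rows_frame.sql(), range_frame.sql(), sep="\n")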
class DataFrameReader:
15class DataFrameReader:
16    def __init__(self, spark: SparkSession):
17        self.spark = spark
18
19    def table(self, tableName: str) -> DataFrame:
20        from sqlglot.dataframe.sql.dataframe import DataFrame
21        from sqlglot.dataframe.sql.session import SparkSession
22
23        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
24
25        return DataFrame(
26            self.spark,
27            exp.Select()
28            .from_(
29                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
30                    SparkSession().dialect.normalize_identifier
31                )
32            )
33            .select(
34                *(
35                    column
36                    for column in sqlglot.schema.column_names(
37                        tableName, dialect=SparkSession().dialect
38                    )
39                )
40            ),
41        )
DataFrameReader(spark: SparkSession)
16    def __init__(self, spark: SparkSession):
17        self.spark = spark
spark
def table(self, tableName: str) -> DataFrame:
19    def table(self, tableName: str) -> DataFrame:
20        from sqlglot.dataframe.sql.dataframe import DataFrame
21        from sqlglot.dataframe.sql.session import SparkSession
22
23        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
24
25        return DataFrame(
26            self.spark,
27            exp.Select()
28            .from_(
29                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
30                    SparkSession().dialect.normalize_identifier
31                )
32            )
33            .select(
34                *(
35                    column
36                    for column in sqlglot.schema.column_names(
37                        tableName, dialect=SparkSession().dialect
38                    )
39                )
40            ),
41        )
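table requires the table to be registered in sqlglot.schema first, so that its column names can be expanded into an explicit SELECT list rather than SELECT *. A sketch following the documented sqlglot workflow (table and columns are illustrative):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table(
        "employee",
        {"employee_id": "INT", "fname": "STRING", "age": "INT"},
        dialect="spark",
    )
    df = SparkSession().read.table("employee")
    print(df.sql(dialect="spark")[0])  # explicit column list instead of SELECT *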
class DataFrameWriter:
 44class DataFrameWriter:
 45    def __init__(
 46        self,
 47        df: DataFrame,
 48        spark: t.Optional[SparkSession] = None,
 49        mode: t.Optional[str] = None,
 50        by_name: bool = False,
 51    ):
 52        self._df = df
 53        self._spark = spark or df.spark
 54        self._mode = mode
 55        self._by_name = by_name
 56
 57    def copy(self, **kwargs) -> DataFrameWriter:
 58        return DataFrameWriter(
 59            **{
 60                k[1:] if k.startswith("_") else k: v
 61                for k, v in object_to_dict(self, **kwargs).items()
 62            }
 63        )
 64
 65    def sql(self, **kwargs) -> t.List[str]:
 66        return self._df.sql(**kwargs)
 67
 68    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
 69        return self.copy(_mode=saveMode)
 70
 71    @property
 72    def byName(self):
 73        return self.copy(by_name=True)
 74
 75    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
 76        from sqlglot.dataframe.sql.session import SparkSession
 77
 78        output_expression_container = exp.Insert(
 79            **{
 80                "this": exp.to_table(tableName),
 81                "overwrite": overwrite,
 82            }
 83        )
 84        df = self._df.copy(output_expression_container=output_expression_container)
 85        if self._by_name:
 86            columns = sqlglot.schema.column_names(
 87                tableName, only_visible=True, dialect=SparkSession().dialect
 88            )
 89            df = df._convert_leaf_to_cte().select(*columns)
 90
 91        return self.copy(_df=df)
 92
 93    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
 94        if format is not None:
 95            raise NotImplementedError("Providing Format in the save as table is not supported")
 96        exists, replace, mode = None, None, mode or str(self._mode)
 97        if mode == "append":
 98            return self.insertInto(name)
 99        if mode == "ignore":
100            exists = True
101        if mode == "overwrite":
102            replace = True
103        output_expression_container = exp.Create(
104            this=exp.to_table(name),
105            kind="TABLE",
106            exists=exists,
107            replace=replace,
108        )
109        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
DataFrameWriter( df: DataFrame, spark: Optional[SparkSession] = None, mode: Optional[str] = None, by_name: bool = False)
45    def __init__(
46        self,
47        df: DataFrame,
48        spark: t.Optional[SparkSession] = None,
49        mode: t.Optional[str] = None,
50        by_name: bool = False,
51    ):
52        self._df = df
53        self._spark = spark or df.spark
54        self._mode = mode
55        self._by_name = by_name
def copy(self, **kwargs) -> DataFrameWriter:
57    def copy(self, **kwargs) -> DataFrameWriter:
58        return DataFrameWriter(
59            **{
60                k[1:] if k.startswith("_") else k: v
61                for k, v in object_to_dict(self, **kwargs).items()
62            }
63        )
def sql(self, **kwargs) -> List[str]:
65    def sql(self, **kwargs) -> t.List[str]:
66        return self._df.sql(**kwargs)
def mode( self, saveMode: Optional[str]) -> DataFrameWriter:
68    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
69        return self.copy(_mode=saveMode)
byName
71    @property
72    def byName(self):
73        return self.copy(by_name=True)
def insertInto( self, tableName: str, overwrite: Optional[bool] = None) -> DataFrameWriter:
75    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
76        from sqlglot.dataframe.sql.session import SparkSession
77
78        output_expression_container = exp.Insert(
79            **{
80                "this": exp.to_table(tableName),
81                "overwrite": overwrite,
82            }
83        )
84        df = self._df.copy(output_expression_container=output_expression_container)
85        if self._by_name:
86            columns = sqlglot.schema.column_names(
87                tableName, only_visible=True, dialect=SparkSession().dialect
88            )
89            df = df._convert_leaf_to_cte().select(*columns)
90
91        return self.copy(_df=df)
def saveAsTable( self, name: str, format: Optional[str] = None, mode: Optional[str] = None):
 93    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
 94        if format is not None:
 95            raise NotImplementedError("Providing Format in the save as table is not supported")
 96        exists, replace, mode = None, None, mode or str(self._mode)
 97        if mode == "append":
 98            return self.insertInto(name)
 99        if mode == "ignore":
100            exists = True
101        if mode == "overwrite":
102            replace = True
103        output_expression_container = exp.Create(
104            this=exp.to_table(name),
105            kind="TABLE",
106            exists=exists,
107            replace=replace,
108        )
109        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
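A sketch of how mode interacts with saveAsTable: per the source above, append routes to insertInto, ignore sets the exists flag (CREATE TABLE IF NOT EXISTS), and overwrite sets replace (CREATE OR REPLACE TABLE). The DataFrame below is built as in the earlier examples:

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, "Jack"]], ["employee_id", "fname"])

    print(df.write.mode("overwrite").saveAsTable("employee").sql()[0])  # CREATE OR REPLACE TABLE ...
    print(df.write.mode("append").saveAsTable("employee").sql()[0])     # INSERT INTO ...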