Mastering PySpark ⚡️ : Best Practices for Efficient Coding

appear in style and sparkly

Author

Joost de Theije + LLM

Published

January 25, 2025

Abstract
PySpark is a powerful tool for processing massive datasets, but it presents unique challenges, especially for experienced Python developers: its syntax draws on Spark's JVM heritage and therefore uses code patterns that may be unfamiliar. This guide provides best practices for efficient PySpark coding, helping you avoid common pitfalls and write clean, performant code. By following these tips, you'll enhance your PySpark skills and improve the reliability and performance of your data processing tasks. Adapted from here and here

1 Standard Practices for Efficient Coding 🧠

1.1 Import PySpark Modules with Default Aliases

Always start with the following imports: use the F prefix for all PySpark functions, T for all PySpark types, and W for Window. Adding the #noqa comment at the end of the line stops linters (like flake8) from flagging it. Importing these lets you use the types and functions without importing each item separately, resulting in a speedy workflow. DataFrame is imported without an alias, because df is the conventional name for dataframe variables and would shadow it. This also eliminates the oopsie of applying Python's built-in sum instead of the PySpark one ☹.

from pyspark.sql import DataFrame, functions as F, types as T, Window as W  # noqa

F.sum("scoops_sold")  # aggregate function from pyspark.sql.functions
T.IntegerType()  # data type from pyspark.sql.types
W.partitionBy("flavor")  # window specification partitioned by a column

1.2 Type Annotations and Docstrings in Functions

Type annotations and docstrings are excellent practices with numerous benefits, including:

  • Improved readability: they make the code easier to understand.
  • Enhanced IDE support: they enable autocompletion, type checking, and other features.
  • Better documentation: they allow you to document your code as you write it, storing the solution with the problem.
  • Static type checking: tools like mypy can help catch errors early.

from pyspark.sql import DataFrame, functions as F

# bad
def calculate_daily_scoops(df, flavor_col, date_col):
    return df.groupBy(date_col, flavor_col).agg(F.sum("scoops_sold").alias("total_scoops"))

# good
def calculate_daily_scoops(df: DataFrame, flavor_col: str, date_col: str) -> DataFrame:
    """Calculate total scoops sold per day and flavor.

    Groups the input DataFrame by date and flavor, aggregating the total number
    of scoops sold for each unique combination.

    Parameters
    ----------
    df : DataFrame
        Input DataFrame containing sales data.
    flavor_col : str
        Name of the column containing ice cream flavor information.
    date_col : str
        Name of the column containing date information.

    Returns
    -------
    DataFrame
        DataFrame with total scoops sold, grouped by date and flavor.
        Contains columns for date, flavor, and total scoops.

    Examples
    --------
    >>> import pyspark.sql.functions as F
    >>> sales_df = spark.createDataFrame(...)
    >>> daily_scoops = calculate_daily_scoops(sales_df, 'flavor', 'sale_date')
    >>> daily_scoops.show()
    """
    return df.groupBy(date_col, flavor_col).agg(
        F.sum("scoops_sold").alias("total_scoops")
    )

1.3 Formatting

Use black, the Python code formatter. Just use black.

2 Implicit column selection

# bad
df = df.select(F.lower(df.colA), F.upper(df.colB))

# good
df = df.select(F.lower(F.col('colA')), F.upper(F.col('colB')))

# better - since Spark 3.0
df = df.select(F.lower('colA'), F.upper('colB'))

In most cases, the third option is best. Spark 3.0 expanded the scenarios where this approach works. Direct attribute access such as df.colA ties the expression to one specific dataframe variable and breaks for column names that are not valid Python identifiers. When selecting a column by string isn't feasible, fall back to F.col(), the second option.

2.1 When to Deviate

In some contexts, you might need to access columns from multiple dataframes with overlapping names. For example, in matching expressions like df.join(df2, on=(df.key == df2.key), how='left'). In such cases, it’s acceptable to reference columns directly by their dataframe. Please see Joins for more details.

3 Logical operations

Logical operations, which often reside inside .filter() or F.when(), need to be readable. Limit logic expressions within a single code block to three (3) at most. If they grow longer, it is often a sign that the code can and should be extracted and/or simplified. Extracting complex logical operations into variables makes the code easier to read and reason about, which in turn also reduces bugs.

# bad
F.when( (F.col('prod_status') == 'Delivered') | (((F.datediff('deliveryDate_actual', 'current_date') < 0) & ((F.col('currentRegistration') != '') | ((F.datediff('deliveryDate_actual', 'current_date') < 0) & ((F.col('originalOperator') != '') | (F.col('currentOperator') != '')))))), 'In Service')

The code above can be simplified in several ways. To start, group the logical steps into named variables. Because Python's & and | operators bind more tightly than comparisons such as == and <, PySpark requires each comparison to be wrapped in parentheses -> (). Mixed with the parentheses that group logical operations, this hurts readability. For example, the code above contains a redundant condition -> (F.datediff('deliveryDate_actual', 'current_date') < 0) that is very hard to spot.

# better
has_operator = ((F.col('originalOperator') != '') | (F.col('currentOperator') != ''))
delivery_date_passed = (F.datediff('deliveryDate_actual', 'current_date') < 0)
has_registration = (F.col('currentRegistration').rlike('.+'))
is_delivered = (F.col('prod_status') == 'Delivered')

F.when(is_delivered | (delivery_date_passed & (has_registration | has_operator)), 'In Service')

The above example drops the redundant expression and is easier to read. We can improve it further by reducing the number of operations and grouping the ones that have a business meaning together.

# good
has_registration = F.col("currentRegistration").rlike(".+")
has_operator = (F.col("originalOperator") != "") | (F.col("currentOperator") != "")
is_active = has_registration | has_operator
is_delivered = F.col("prod_status") == "Delivered"
delivery_date_passed = F.datediff("deliveryDate_actual", "current_date") < 0

F.when(is_delivered | (delivery_date_passed & is_active), "In Service")

The F.when expression is now readable, and the desired behavior is clear to anyone reviewing this code. Surprise: that person is going to be you in six months. The reader only needs to visit the individual expressions if they suspect there is an error. It also makes each chunk of logic easy to reason about and test. You do have unit tests, right? If you want, these expressions can be extracted into their own functions.

4 Use select to specify a schema contract

A select at the beginning or end of a PySpark transform specifies the contract with both the reader and the code about the expected dataframe schema for inputs and outputs. Keep select statements as simple as possible. Apply only one function from pyspark.sql.functions per column, plus an optional .alias() to give it a meaningful name. Expressions involving more than one dataframe or conditional operations like .when() are discouraged in a select.

# bad
aircraft = aircraft.select(
    'aircraft_id',
    'aircraft_msn',
    F.col('aircraft_registration').alias('registration'),
    'aircraft_type',
    F.avg('staleness').alias('avg_staleness'),
    F.col('number_of_economy_seats').cast('long'),
    F.avg('flight_hours').alias('avg_flight_hours'),
    'operator_code',
    F.col('number_of_business_seats').cast('long'),
)

Unless order matters to you, try to cluster together operations of the same type, reducing the cognitive load on the reader of the code.

# good
aircraft = aircraft.select(
    "aircraft_id",
    "aircraft_msn",
    "aircraft_type",
    "operator_code",
    F.col("aircraft_registration").alias("registration"),
    F.col("number_of_economy_seats").cast("long"),
    F.col("number_of_business_seats").cast("long"),
    F.avg("staleness").alias("avg_staleness"),
    F.avg("flight_hours").alias("avg_flight_hours"),
)

4.1 Modifying columns in select

The select() statement redefines the schema of a dataframe, so it naturally supports the inclusion or exclusion of columns, old and new, as well as the redefinition of pre-existing ones. By centralizing all such operations in a single statement, it becomes much easier to identify the final schema, which makes code more readable. It also makes code more concise.

Instead of calling withColumnRenamed(), use alias():

# bad
df.select('key', 'comments').withColumnRenamed('comments', 'num_comments')

# good
df.select("key", F.col("comments").alias("num_comments"))

Instead of using withColumn() to redefine a column's type, use cast() in the select:

# bad
df.select('comments').withColumn('comments', F.col('comments').cast('double'))

# good
df.select(F.col("comments").cast("double"))

4.2 Inclusive selection above exclusive

Select only the columns you need and avoid drop(): as upstream schemas change over time, drop() may return a different dataframe than expected, and fixing it will be your job.

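Finally, when adding a single new column, .withColumn() is preferred over the select statement. When adding or manipulating tens or hundreds of columns, use a single .select() instead: each .withColumn() call adds a projection to the query plan, so calling it in a loop can produce very large plans that hurt performance and can even cause stack overflow errors.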

5 Empty columns

If you need to add an empty column to satisfy a schema, always use F.lit(None) to populate it. Never use an empty string or some other string signaling an empty value (such as '-' or 'NA'; the latter can be interpreted as ‘North America’, for example).

Beyond being technically correct (the best kind of correct), one practical reason for using F.lit(None) is preserving the ability to use utilities like isNull, instead of having to verify empty strings, nulls, and 'NA', etc. Your future self will thank you.

# bad
df = df.withColumn('foo', F.lit(''))

# also bad
df = df.withColumn('foo', F.lit('NA'))

# good
df = df.withColumn("foo", F.lit(None))

6 Using comments

While comments can provide useful insight into code, it is often more valuable to refactor the code to improve its readability and use comments only to explain the why or to provide more context. Most, if not all, people looking at the code will be able to understand what it is doing, but not why it is doing it.

# bad

# Cast the timestamp columns
cols = ['start_date', 'delivery_date']
for c in cols:
    df = df.withColumn(c, F.from_unixtime(F.col(c) / 1000).cast(T.TimestampType()))

In the example above, we can see that those columns are getting cast to Timestamp. The comment doesn’t add value. Moreover, a more verbose comment might still be unhelpful if it only provides information that already exists in the code. For example:

# still bad

# Go through each column, divide by 1000 because millis and cast to timestamp
cols = ['start_date', 'delivery_date']
for c in cols:
    df = df.withColumn(c, F.from_unixtime(F.col(c) / 1000).cast(T.TimestampType()))

Instead of leaving comments that only describe the logic you wrote, aim to leave comments that give context and explain the why of decisions you made when writing the code. This is particularly important for PySpark since the reader can understand your code but frequently doesn’t have context on the data that feeds into your PySpark transform. Small pieces of logic might have involved hours of digging👷 through data to understand the correct behavior, in which case comments explaining the rationale are especially valuable.

# good

# The consumer of this dataset expects a timestamp instead of a date, and we need
# to divide by 1000 because the original datasource stores these as millis,
# even though the documentation says it's actually a date.
cols = ["start_date", "delivery_date"]
for c in cols:
    df = df.withColumn(c, F.from_unixtime(F.col(c) / 1000).cast(T.TimestampType()))

7 UDFs (user defined functions)

Avoid UDFs whenever possible: Python UDFs are dramatically less performant than native PySpark functions, because every row must be serialized between the JVM and a Python worker process, and the optimizer cannot see inside the function. In most situations, logic that seems to necessitate a UDF can be refactored to use only native PySpark functions. However, there are some situations where a UDF is unavoidable.

8 Joins

Be careful with joins! If you perform a left join, and the right side has multiple matches for a key, that row will be duplicated as many times as there are matches. This is called a “join explosion” and can dramatically bloat the output of your transform job. Always double-check your assumptions to see that the key you are joining on is unique, unless you are expecting the multiplication.

Bad joins are the source of many tricky-to-debug issues. Always pass the join type by name, even if you are using the default (inner). You did know that inner is the default, right?

# bad
flights = flights.join(aircraft, 'aircraft_id')

# also bad
flights = flights.join(aircraft, 'aircraft_id', 'inner')

# good
flights = flights.join(other=aircraft, on="aircraft_id", how="inner")

Avoid right joins. If you are about to use a right join, switch the order of your dataframes and use a left join instead. This is more intuitive, since the dataframe you are operating on is the one your join is centered around. It may seem like a small thing, but it isn't.

# bad
flights = aircraft.join(flights, 'aircraft_id', how='right')

# good
flights = flights.join(other=aircraft, on="aircraft_id", how="left")

8.1 Column collisions when joining

Avoid renaming every column to prevent collisions. Instead, give an alias to the whole dataframe and use that alias to select the columns you want in the end.

# bad
columns = ['start_time', 'end_time', 'idle_time', 'total_time']
for col in columns:
    flights = flights.withColumnRenamed(col, 'flights_' + col)
    parking = parking.withColumnRenamed(col, 'parking_' + col)

flights = flights.join(parking, on='flight_code', how='left')

flights = flights.select(
    F.col('flights_start_time').alias('flight_start_time'),
    F.col('flights_end_time').alias('flight_end_time'),
    F.col('parking_total_time').alias('client_parking_total_time')
)

# good
flights = flights.alias("flights")
parking = parking.alias("parking")

flights = flights.join(other=parking, on="flight_code", how="left")

flights = flights.select(
    F.col("flights.start_time").alias("flight_start_time"),
    F.col("flights.end_time").alias("flight_end_time"),
    F.col("parking.total_time").alias("client_parking_total_time"),
)

In such cases, keep in mind:

  • It is a better idea to select only the columns that are needed before joining.
  • If you do need both, it might be best to rename one of them before joining, signaling the difference between the two columns, as the underlying data-generating process is most likely different.
  • Always resolve ambiguous columns before outputting a dataset. Once the transform is finished, you can no longer distinguish between the columns.

8.2 .dropDuplicates() and .distinct() to “clean” joins

Don’t reach for .dropDuplicates() or .distinct() as a quick fix for duplicated data after a join. If unexpected duplicate rows appear in your dataframe, there is always an underlying reason for them. Adding .dropDuplicates() only masks the problem and adds an unnecessary shuffle and CPU cycles.

9 Window Functions

Window functions are incredibly useful when performing data transformations, especially when you need to perform calculations that require a specific order or grouping of data. They allow you to perform operations across a set of table rows that are somehow related to the current row. This is particularly powerful for operations like running totals, moving averages, and rank calculations.

Always explicitly define three things when working with window functions:

  • partitionBy: The partitions or groups over which the window function will be applied.
  • orderBy: The ordering of rows within the partition.
  • rowsBetween or rangeBetween: Defines the scope in rows or ranges considered when applying a function over the ordered partition.

By specifying these three components, you ensure that your windows are properly defined.

from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
    [("a", 3), ("a", 4), ("a", 1), ("a", 2)],
    ["key", "num"], 
)

# bad
w1 = W.partitionBy('key')

# better
w2 = W.partitionBy('key').orderBy('num')

df.select("key", "num", F.sum("num").over(w1).alias("sum")).show()
# +---+---+---+
# |key|num|sum|
# +---+---+---+
# |  a|  3| 10|
# |  a|  4| 10|
# |  a|  1| 10|
# |  a|  2| 10|
# +---+---+---+
df.select("key", "num", F.sum("num").over(w2).alias("sum")).show()
# +---+---+---+
# |key|num|sum|
# +---+---+---+
# |  a|  1|  1|
# |  a|  2|  3|
# |  a|  3|  6|
# |  a|  4| 10|
# +---+---+---+

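Without an orderBy (w1), the frame defaults to the entire partition. Once orderBy is added (w2), Spark implicitly switches to a growing frame that runs from the first row of the partition up to the current row (rangeBetween(W.unboundedPreceding, W.currentRow)), so each row only sees the rows ordered before it plus any rows that tie with it. For sum this yields a running total, which may or may not be what you want; with functions such as first or last, the problem becomes more apparent.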


df.select("key", "num", F.last("num").over(w1).alias("last")).show()
# +---+---+----+
# |key|num|last|
# +---+---+----+
# |  a|  3|   2|
# |  a|  4|   2|
# |  a|  1|   2|
# |  a|  2|   2|
# +---+---+----+

df.select("key", "num", F.last("num").over(w2).alias("last")).show()
# +---+---+----+
# |key|num|last|
# +---+---+----+
# |  a|  1|   1|
# |  a|  2|   2|
# |  a|  3|   3|
# |  a|  4|   4|
# +---+---+----+

With w1, the result of last depends on the order in which the rows happen to be stored within the partition. That order can change when the dataframe is shuffled, making the results of your transformation non-deterministic. With w2, last simply returns the current row's value because of the implicit frame, which is rarely what was intended.

It is a good idea to always specify an explicit window with all three components: partitionBy, orderBy and rowsBetween/rangeBetween. That way the window behaves deterministically, provided the orderBy columns uniquely order the rows within each partition.

# good
w3 = (
    W.partitionBy("key")
    .orderBy("num")
    .rowsBetween(
        start=W.unboundedPreceding,
        end=0, #<- zero means the current row, -1 is the row before the current
    )
)
# good
w4 = (
    W.partitionBy("key")
    .orderBy("num")
    .rowsBetween(
        start=W.unboundedPreceding,
        end=W.unboundedFollowing,
    )
)

df.select("key", "num", F.sum("num").over(w3).alias("sum")).show()
# +---+---+---+
# |key|num|sum|
# +---+---+---+
# |  a|  1|  1|
# |  a|  2|  3|
# |  a|  3|  6|
# |  a|  4| 10|
# +---+---+---+


df.select("key", "num", F.sum("num").over(w4).alias("sum")).show()
# +---+---+---+
# |key|num|sum|
# +---+---+---+
# |  a|  1| 10|
# |  a|  2| 10|
# |  a|  3| 10|
# |  a|  4| 10|
# +---+---+---+

Now the window is fully defined, with w3 being an expanding window and w4 containing the entire partition. Even if the returned rows are in a different order, the outcome of the window function remains consistent and deterministic.


df.select("key", "num", F.last("num").over(w3).alias("last")).show()
# +---+---+----+
# |key|num|last|
# +---+---+----+
# |  a|  1|   1|
# |  a|  2|   2|
# |  a|  3|   3|
# |  a|  4|   4|
# +---+---+----+


df.select("key", "num", F.last("num").over(w4).alias("last")).show()
# +---+---+----+
# |key|num|last|
# +---+---+----+
# |  a|  1|   4|
# |  a|  2|   4|
# |  a|  3|   4|
# |  a|  4|   4|
# +---+---+----+

Here we can see that the results are as we expect them to be. The results are deterministic and the order of the rows is not influencing the results.

9.1 Dealing with nulls

While nulls are ignored for aggregate functions (like F.sum() and F.max()), they will impact the result of analytic functions (such as F.first()/F.last() and F.rank()).

df_nulls = spark.createDataFrame(
    [("a", None), ("a", 2), ("a", 1), ("a", 3), ("a", None)],
    ["key", "num"],
)


df_nulls.select("key", "num", F.first("num").over(w4).alias("first")).show()
# +---+----+-----+
# |key| num|first|
# +---+----+-----+
# |  a|NULL| NULL|
# |  a|NULL| NULL|
# |  a|   1| NULL|
# |  a|   2| NULL|
# |  a|   3| NULL|
# +---+----+-----+

Depending on the use case this might be the desired behavior, but it might not always be applicable. It is best to avoid this situation by explicitly setting the sort order and the nulls behavior.

# sorting ascending and nulls last
w5 = (
    W.partitionBy("key")
    .orderBy(F.asc_nulls_last("num"))  
    .rowsBetween(
        start=W.unboundedPreceding,
        end=W.unboundedFollowing,
    )
)

df_nulls.select("key", "num", F.first("num").over(w5).alias("first")).show()
# +---+----+-----+
# |key| num|first|
# +---+----+-----+
# |  a|   1|    1|
# |  a|   2|    1|
# |  a|   3|    1|
# |  a|NULL|    1|
# |  a|NULL|    1|
# +---+----+-----+

9.2 Empty partitionBy()

Spark window functions can be applied over all rows, using a global frame. This is accomplished by specifying zero columns in the partition by expression (i.e. W.partitionBy()). Code like this should be avoided, however, as it forces Spark to combine all data into a single partition, which can be extremely harmful for performance. Prefer to use aggregations whenever possible:

# bad
w = W.partitionBy()
df.select(F.sum('num').over(w).alias('sum'))
# WARN WindowExec: No Partition Defined for Window operation! 
# Moving all data to a single partition, this can cause serious performance degradation


# good
df.agg(F.sum("num").alias("sum"))

10 Other Considerations and Recommendations

  1. Be cautious of functions that grow too large. Generally, a file should not exceed 250 lines, and a function should not exceed 70 lines.
  2. Organize your code into logical blocks. For example, if you have multiple lines referencing the same elements, group them together. Separating them diminishes context and readability.
  3. Test your code! If you can run the local tests, do so and make sure that your new code is covered by the tests. If you can’t run the local tests, build the datasets on your branch and manually verify that the data looks as expected.
  4. Avoid using .otherwise(value) as a general fallback. If you are mapping a list of keys to a list of values and encounter unknown keys, otherwise will collapse all of them into a single value and hide them.
  5. When encountering a large, single transformation that integrates multiple different source tables, break it into sub-steps and extract the logic into functions. This enhances higher-level readability and promotes code reusability and consistency between transformations.
  6. Be as explicit and descriptive as possible when naming functions or variables. Aim to capture the function’s or variable’s purpose rather than naming it based on the objects it uses.
  7. Avoid using literal strings or integers in filtering conditions, new column values, etc. Instead, to capture their meaning, extract them into variables, constants, dictionaries, or classes as appropriate. This enhances code readability and enforces consistency across the repository.