
Handling skewed joins in Apache Spark

Introduction

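When you join two large tables, some keys can appear far more often than others. In a shuffle join, all rows with the same key go to the same partition. A few tasks therefore receive most of the data and run much longer than the rest, and the whole join waits for them. Handling skewed joins spreads that work more evenly across tasks. Common examples: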

Joining customer data where a few customers have many transactions
Combining sales data where some products are extremely popular
Merging logs where some IP addresses appear very often
Joining user activity data where a few users are very active
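To see why one hot key slows everything down, here is a small pure-Python simulation (not Spark code). The data is hypothetical, and `key % num_partitions` is an assumed stand-in for Spark's hash partitioner (Spark actually uses a Murmur3 hash). The property that matters is the same: every row with a given key lands in the same partition.

```python
from collections import Counter


def partition_sizes(keys, num_partitions):
    """Count rows per partition, sending each row to key % num_partitions.

    Stand-in for shuffle hash partitioning: rows with equal keys always
    end up in the same partition, and so in the same task.
    """
    counts = Counter(key % num_partitions for key in keys)
    return [counts.get(p, 0) for p in range(num_partitions)]


# Hypothetical data: keys 0..99 with 10 rows each, plus 5,000 extra rows for key 7.
keys = [k for k in range(100) for _ in range(10)] + [7] * 5000

sizes = partition_sizes(keys, num_partitions=8)
print(sizes)  # [130, 130, 130, 130, 120, 120, 120, 5120]
```

Seven partitions hold 120 to 130 rows each, while the partition containing key 7 holds about 40 times as many. The task that processes that partition decides how long the join takes.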
Syntax
Apache Spark
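# Assumes an existing SparkSession `spark`, a large DataFrame `df` whose join
# column `key` is skewed, and a smaller DataFrame `other_df` with the same `key` column.
from pyspark.sql.functions import rand

# Disable automatic broadcast joins so the join below runs as a shuffle join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

NUM_SALTS = 10  # number of buckets each key is split into

# Example of salting technique:
# add a random salt from 0 to NUM_SALTS - 1 to every row of the large, skewed table
skewed_df = df.withColumn("salt", (rand() * NUM_SALTS).cast("int"))

# Create one row per salt value
salts = spark.range(NUM_SALTS).withColumnRenamed("id", "salt")

# Replicate every row of the small table once per salt value
small_df = other_df.crossJoin(salts)

# Join on key and salt
joined_df = skewed_df.join(small_df, on=["key", "salt"])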

Setting spark.sql.autoBroadcastJoinThreshold to -1 disables automatic broadcast joins, so Spark uses a shuffle join (typically a sort-merge join) instead. Explicit broadcast hints are still honored.

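Salting adds a random salt column (here 0 to 9) to the large, skewed table and replicates every row of the smaller table once per salt value. Joining on both key and salt splits each hot key into up to 10 smaller groups that can be processed in different partitions, while every original match is still found exactly once.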
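A pure-Python simulation (not Spark code) of the effect described above, on hypothetical data. The partitioner is an assumption: it sends `key` to `key % NUM_PARTITIONS` without salting and combines `(key, salt)` into one number with salting. Spark instead hashes the full join key, but both share the property that different salts of one key can land in different partitions.

```python
import random
from collections import Counter

NUM_SALTS = 10       # same role as the 10 in rand() * 10
NUM_PARTITIONS = 8


def partition_of(key, salt=None):
    """Stand-in partitioner: key alone when unsalted, a (key, salt) combination when salted."""
    if salt is None:
        return key % NUM_PARTITIONS
    return (key * NUM_SALTS + salt) % NUM_PARTITIONS


def sizes(assignments):
    """Turn a stream of partition ids into a list of rows per partition."""
    counts = Counter(assignments)
    return [counts.get(p, 0) for p in range(NUM_PARTITIONS)]


rng = random.Random(42)  # seeded so the run is repeatable
# Hypothetical data: keys 0..99 with 10 rows each, plus 5,000 extra rows for key 7.
keys = [k for k in range(100) for _ in range(10)] + [7] * 5000

unsalted = sizes(partition_of(k) for k in keys)
salted = sizes(partition_of(k, rng.randrange(NUM_SALTS)) for k in keys)

print("unsalted:", unsalted)
print("salted:  ", salted)
```

Without salting, one partition holds 5,120 of the 6,000 rows. With salting, key 7 is split into 10 buckets of roughly 500 rows each. The split is not perfectly even here because 10 salt buckets are spread over 8 partitions, so two partitions receive two buckets each.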

Examples
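This disables automatic broadcast joins and forces a shuffle join. Skew mainly hurts shuffle joins, where all rows with the same key are sent to one partition. The setting is used here so the examples actually go through a shuffle; it does not fix skew by itself.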
Apache Spark
# Disable broadcast join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
rand() returns a uniform random value in [0, 1), so (rand() * 10).cast("int") gives each row an integer from 0 to 9. This spreads the rows of a skewed key over 10 salt values.
Apache Spark
# Add salt column to skewed dataframe
from pyspark.sql.functions import rand
skewed_df = df.withColumn("salt", (rand() * 10).cast("int"))
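The salt expression can be mirrored in plain Python to check its range and how evenly it splits one hot key. This is a sketch, not Spark code: `random.random()` stands in for Spark's `rand()`, since both return uniform values in [0, 1), and the 1,000-row hot key is hypothetical, like key=1 in the sample program.

```python
import random
from collections import Counter

NUM_SALTS = 10
rng = random.Random(0)  # seeded so the run is repeatable


def make_salt():
    """Python analogue of (rand() * NUM_SALTS).cast("int"): scale a [0, 1) value and truncate."""
    return int(rng.random() * NUM_SALTS)


# Hypothetical hot key with 1,000 rows.
salts = [make_salt() for _ in range(1000)]
per_salt = Counter(salts)

print(sorted(per_salt))        # salt values that were drawn
print(max(per_salt.values()))  # rows in the largest (key, salt) group
```

Every salt falls in 0 to 9, and each bucket ends up with about 100 rows instead of one group of 1,000.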
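Replicate every row of the smaller (non-skewed) table once per salt value, then join on both key and salt. Each salted row of the large table matches exactly one copy of its key, so the result contains the same rows as an unsalted join, while the work for a hot key is split across salt values.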
Apache Spark
# Cross join non-skewed df with salt values
salts = spark.range(10).withColumnRenamed("id", "salt")
non_skewed_df = other_df.crossJoin(salts)

# Join on key and salt
joined_df = skewed_df.join(non_skewed_df, on=["key", "salt"])
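A plain-Python sketch (not Spark's join implementation) that checks the claim that salting loses no data: it joins hypothetical rows shaped like the sample program's, once on `key` alone and once on `key` and `salt`, and compares the results after dropping the salt column. `hash_join` and `without_salt` are illustrative helpers, not Spark APIs.

```python
import random
from collections import Counter

NUM_SALTS = 10


def hash_join(left, right, on):
    """Inner equi-join of two lists of dicts on the columns listed in `on`."""
    index = {}
    for rrow in right:
        index.setdefault(tuple(rrow[c] for c in on), []).append(rrow)
    return [
        {**lrow, **rrow}
        for lrow in left
        for rrow in index.get(tuple(lrow[c] for c in on), [])
    ]


def without_salt(rows):
    """Drop the helper salt column and count rows, so results compare as multisets."""
    return Counter(
        tuple(sorted((col, val) for col, val in row.items() if col != "salt"))
        for row in rows
    )


# Hypothetical data shaped like the sample program: key 1 is heavily skewed.
large = [{"key": 1, "value": f"val_{i}"} for i in range(1000)] + [{"key": 2, "value": "val_1000"}]
small = [{"key": 1, "info": "info_1"}, {"key": 2, "info": "info_2"}]

rng = random.Random(7)  # seeded so the run is repeatable
# Large side: one random salt per row, like withColumn("salt", (rand() * 10).cast("int")).
salted_large = [{**row, "salt": rng.randrange(NUM_SALTS)} for row in large]
# Small side: one copy of every row per salt value, like crossJoin with the salt values.
salted_small = [{**row, "salt": s} for row in small for s in range(NUM_SALTS)]

plain = hash_join(large, small, on=["key"])
salted = hash_join(salted_large, salted_small, on=["key", "salt"])

print(len(plain), len(salted))                      # 1001 1001
print(without_salt(plain) == without_salt(salted))  # True
```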
Sample Program

This program creates a skewed dataset where key=1 appears 1000 times. It adds a salt column to spread the join load. The smaller dataframe is expanded with all salt values to match. The join is done on both key and salt to balance the work.

Apache Spark
from pyspark.sql import SparkSession
from pyspark.sql.functions import rand

spark = SparkSession.builder.appName("SkewedJoinExample").getOrCreate()

# Create skewed dataframe with many rows for key=1
data1 = [(1, f"val_{i}") for i in range(1000)] + [(2, "val_1000")]

# Create smaller dataframe
data2 = [(1, "info_1"), (2, "info_2")]

# Create DataFrames
skewed_df = spark.createDataFrame(data1, ["key", "value"])
other_df = spark.createDataFrame(data2, ["key", "info"])

# Disable broadcast join to force shuffle join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

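# Number of salt values; the same value must be used on both sides of the join
NUM_SALTS = 10

# Add a random salt (0 to NUM_SALTS - 1) to every row of skewed_df
skewed_df = skewed_df.withColumn("salt", (rand() * NUM_SALTS).cast("int"))

# Replicate each other_df row once per salt value so every salted row finds its match
salt_values = spark.range(NUM_SALTS).withColumnRenamed("id", "salt")
other_df_salted = other_df.crossJoin(salt_values)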

# Join on key and salt
joined_df = skewed_df.join(other_df_salted, on=["key", "salt"])

# Show count to verify join
print(f"Joined rows count: {joined_df.count()}")

spark.stop()
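Output:
Joined rows count: 1001

Every row of skewed_df (1,000 rows for key=1 plus 1 row for key=2) matches exactly one salted copy of its key in other_df_salted, so the count equals that of an unsalted join.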
Important Notes

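Salting works best when you know which keys are skewed. The examples here salt every key; when only a few keys are hot, you can salt just those keys and replicate only their rows on the other side, which keeps the extra data small.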
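A plain-Python sketch of salting only the hot keys, on hypothetical data. The `HOT_THRESHOLD` cutoff and the choice of salt 0 for non-hot rows are assumptions for illustration. In Spark, finding the hot keys would typically be a `groupBy("key").count()` on the large table.

```python
import random
from collections import Counter

NUM_SALTS = 10
HOT_THRESHOLD = 100  # hypothetical cutoff: keys with more rows than this count as "hot"

# Hypothetical data: key 1 is hot, keys 2..50 have one row each.
large = [{"key": 1, "value": i} for i in range(1000)] + [
    {"key": k, "value": -k} for k in range(2, 51)
]
small = [{"key": k, "info": f"info_{k}"} for k in range(1, 51)]

# 1. Find hot keys on the large side.
counts = Counter(row["key"] for row in large)
hot_keys = {k for k, n in counts.items() if n > HOT_THRESHOLD}

# 2. Give hot-key rows a random salt; every other row gets salt 0.
rng = random.Random(1)  # seeded so the run is repeatable
salted_large = [
    {**row, "salt": rng.randrange(NUM_SALTS) if row["key"] in hot_keys else 0}
    for row in large
]

# 3. Replicate only hot-key rows on the small side; other rows keep one copy with salt 0.
salted_small = [
    {**row, "salt": s}
    for row in small
    for s in (range(NUM_SALTS) if row["key"] in hot_keys else [0])
]

# 4. Equi-join on (key, salt).
index = {}
for row in salted_small:
    index.setdefault((row["key"], row["salt"]), []).append(row)
joined = [
    {**lrow, **srow}
    for lrow in salted_large
    for srow in index.get((lrow["key"], lrow["salt"]), [])
]

print(hot_keys)                        # {1}
print(len(salted_small), len(joined))  # 59 1049
```

The small side grows from 50 to 59 rows instead of to 500 rows with full replication, and every one of the 1,049 large-side rows still finds its single match.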

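The smaller table is replicated once per salt value, so N salt values make it N times larger. Choose N large enough to split the hottest key into manageable pieces, but no larger.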
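The trade-off above is simple arithmetic, sketched here in Python with hypothetical table sizes. The helper names and the idea of a target bucket size are assumptions for illustration, and the per-bucket figure is an average, since random salting only splits a key approximately evenly.

```python
import math


def salting_cost(hot_key_rows, small_side_rows, num_salts):
    """Average rows per (hot key, salt) bucket and small-side rows after replication."""
    return {
        "rows_per_hot_bucket": math.ceil(hot_key_rows / num_salts),
        "replicated_small_rows": small_side_rows * num_salts,
    }


def smallest_salt_count(hot_key_rows, target_rows_per_bucket):
    """Fewest salt values that bring the hot key down to about the target bucket size."""
    return math.ceil(hot_key_rows / target_rows_per_bucket)


# Hypothetical sizes: a hot key with 50 million rows joined to a 2-million-row table.
for n in (1, 10, 100):
    print(n, salting_cost(50_000_000, 2_000_000, n))

print(smallest_salt_count(50_000_000, 5_000_000))  # 10
```

Going from 10 to 100 salts shrinks each hot bucket from 5 million to 500,000 rows, but also grows the replicated small table from 20 million to 200 million rows.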

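Disabling broadcast joins does not fix skew by itself: it forces a shuffle join, which is exactly where skew causes slow tasks. The examples disable it only so that the tiny demo tables are not broadcast. If the smaller table fits in executor memory, a broadcast join (for example with broadcast(other_df) from pyspark.sql.functions) avoids shuffling the large, skewed table altogether.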
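A pure-Python sketch (not Spark's implementation) of the idea behind a broadcast hash join. The small table is copied to every task as a lookup dict, and each partition of the large table is joined where it already sits, with no repartitioning by key. The data and the four-way input split are hypothetical.

```python
def broadcast_hash_join(large_partitions, small_rows):
    """Join each large-side partition against an in-memory copy of the small side."""
    lookup = {}
    for row in small_rows:
        lookup.setdefault(row["key"], []).append(row)
    # Each partition is processed independently; rows never move between partitions.
    return [
        [{**lrow, **srow} for lrow in part for srow in lookup.get(lrow["key"], [])]
        for part in large_partitions
    ]


# Hypothetical large side: 1,000 rows, all with hot key 1, split into 4 input partitions.
large = [{"key": 1, "value": i} for i in range(1000)]
partitions = [large[i::4] for i in range(4)]
small = [{"key": 1, "info": "info_1"}, {"key": 2, "info": "info_2"}]

result = broadcast_hash_join(partitions, small)
print([len(p) for p in result])  # [250, 250, 250, 250]
```

Even though every row has the same key, each task handles only its own 250 rows, because nothing is regrouped by key.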

Summary

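Skewed joins happen when a few keys have far more rows than the rest. Because a shuffle join sends all rows for a key to the same partition, a few tasks run much longer than the others.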

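Salting adds a random number to the join key of the large table and replicates the smaller table once per salt value, spreading a hot key's rows across several partitions without changing the join result.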

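Disabling broadcast joins forces a shuffle join, which the examples need in order to demonstrate salting, but it does not reduce skew itself. When the smaller table is small enough, a broadcast join avoids the skewed shuffle entirely.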