Handling skewed joins in Apache Spark

When joining two large tables, a few keys can appear far more often than the rest. This data skew makes the join slow and uneven, because the tasks that process the hot keys run much longer than the others. Handling skewed joins balances the work across partitions and speeds up the whole job.
Setting spark.sql.autoBroadcastJoinThreshold to -1 disables automatic broadcast joins, forcing Spark to use a shuffle join.
Salting adds a random number to the join key to spread out skewed keys across partitions.
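To see why this works, here is a minimal pure-Python sketch (no Spark required) that simulates hash partitioning. All rows with the same key land in one partition; once a random salt is mixed into the key, the hot key fans out. The partition and salt counts are illustrative assumptions.

```python
import random
from collections import Counter

NUM_PARTITIONS = 10
NUM_SALTS = 10

# 1,000 rows for the hot key "A", a handful for the others
rows = ["A"] * 1000 + ["B"] * 5 + ["C"] * 5

# Without salting: every row with the same key hashes to the same
# partition, so one partition receives all 1,000 "A" rows
unsalted = Counter(hash(k) % NUM_PARTITIONS for k in rows)

# With salting: the (key, salt) pair hashes differently per salt value,
# so the hot key's rows spread across many partitions
random.seed(0)
salted = Counter(hash((k, random.randrange(NUM_SALTS))) % NUM_PARTITIONS
                 for k in rows)

print("max rows in one partition, unsalted:", max(unsalted.values()))
print("max rows in one partition, salted:  ", max(salted.values()))
```

Running this shows the largest unsalted partition holding at least 1,000 rows, while the salted version spreads that load roughly tenfold thinner.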
```python
# Disable broadcast join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
```
```python
# Add a random salt column (0-9) to the skewed dataframe
from pyspark.sql.functions import rand

skewed_df = df.withColumn("salt", (rand() * 10).cast("int"))
```
```python
# Cross join the non-skewed df with all salt values, replicating it
# once per salt so every salted key can still find its match
salts = spark.range(10).withColumnRenamed("id", "salt")
non_skewed_df = other_df.crossJoin(salts)

# Join on both the original key and the salt
joined_df = skewed_df.join(non_skewed_df, on=["key", "salt"])
```
The complete program below creates a skewed dataset in which key=1 appears 1,000 times. It adds a salt column to spread the join load, expands the smaller dataframe with every salt value so the keys still match, and joins on both key and salt to balance the work.
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import rand

spark = SparkSession.builder.appName("SkewedJoinExample").getOrCreate()

# Create skewed data with many rows for key=1
data1 = [(1, f"val_{i}") for i in range(1000)] + [(2, "val_1000")]
# Create smaller data
data2 = [(1, "info_1"), (2, "info_2")]

# Create DataFrames
skewed_df = spark.createDataFrame(data1, ["key", "value"])
other_df = spark.createDataFrame(data2, ["key", "info"])

# Disable broadcast join to force shuffle join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# Add salt column to skewed_df
skewed_df = skewed_df.withColumn("salt", (rand() * 10).cast("int"))

# Replicate other_df with every salt value (0 to 9) so the keys still match
salt_values = spark.range(10).withColumnRenamed("id", "salt")
other_df_salted = other_df.crossJoin(salt_values)

# Join on key and salt
joined_df = skewed_df.join(other_df_salted, on=["key", "salt"])

# Show count to verify the join
print(f"Joined rows count: {joined_df.count()}")

spark.stop()
```
Salting works best when you know which keys are skewed.
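Before salting, it helps to confirm which keys are actually hot; in Spark you would typically inspect df.groupBy("key").count(). The plain-Python sketch below shows the same check, where the 2x-average threshold is an illustrative choice, not a rule:

```python
from collections import Counter

def find_skewed_keys(keys, threshold_ratio=2.0):
    """Flag keys whose row count exceeds threshold_ratio times the mean."""
    counts = Counter(keys)
    mean = sum(counts.values()) / len(counts)
    return {k: c for k, c in counts.items() if c > threshold_ratio * mean}

rows = ["A"] * 1000 + ["B"] * 5 + ["C"] * 5
print(find_skewed_keys(rows))  # -> {'A': 1000}
```

Only the keys this check flags need the salting treatment; the rest of the data can be joined normally.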
Too many salt values inflate the replicated table and add shuffle overhead, so pick the smallest count that levels out the hot keys.
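This trade-off can be made concrete with a back-of-the-envelope model: each extra salt value replicates the small table once more, but shrinks the largest task. The row counts below are assumed for illustration, not measured:

```python
# Rough cost model for choosing the salt count (illustrative numbers)
HOT_KEY_ROWS = 1_000_000    # rows behind the single skewed key
SMALL_TABLE_ROWS = 10_000   # rows in the table that gets replicated

def salting_costs(num_salts):
    replicated_rows = SMALL_TABLE_ROWS * num_salts  # extra rows shuffled
    hot_rows_per_task = HOT_KEY_ROWS // num_salts   # work left per reducer
    return replicated_rows, hot_rows_per_task

for n in (1, 5, 10, 100):
    rep, per_task = salting_costs(n)
    print(f"salts={n:3d}  replicated_rows={rep:9,d}  "
          f"hot_rows_per_task={per_task:9,d}")
```

Past the point where tasks are roughly even, adding more salts only pays the replication cost without any further balancing benefit.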
In this example, the broadcast threshold is disabled only to force a shuffle join and make the skew visible; when one side is small enough to fit in memory, broadcasting it avoids the shuffle entirely and is often the simplest fix for skew.
To summarize:
- Skewed joins happen when a few keys account for a large share of the rows, leaving some tasks far slower than others.
- Salting appends a random number to the join key so those rows spread evenly across partitions.
- Broadcast joins were disabled here only to demonstrate shuffle-join skew; salting is what actually balances the work.