How to Use Pivot in PySpark: Syntax and Example
In PySpark, call the pivot() function on a grouped DataFrame to turn the unique values of one column into multiple columns. Combine it with an aggregation function such as sum() or count() to summarize the data in those new columns.

Syntax
The pivot() function is used after grouping a DataFrame with groupBy(). It takes the name of the column whose unique values become the new columns, and an aggregation function then summarizes the data.
- groupBy(cols): Groups the data by one or more columns.
- pivot(pivot_col): Turns the unique values in pivot_col into columns.
- agg(aggregation): Aggregates the values for each new column.
```python
df.groupBy('group_col').pivot('pivot_col').agg({'value_col': 'sum'})
```
Example
This example shows how to pivot a DataFrame of sales data to get total sales per product for each store.
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import sum as sum_  # alias to avoid shadowing the built-in sum

spark = SparkSession.builder.appName('PivotExample').getOrCreate()

data = [
    ('Store1', 'Apples', 10),
    ('Store1', 'Bananas', 20),
    ('Store2', 'Apples', 15),
    ('Store2', 'Bananas', 5),
    ('Store1', 'Oranges', 7),
    ('Store2', 'Oranges', 10),
]
columns = ['store', 'product', 'sales']
df = spark.createDataFrame(data, columns)

# One row per store, one column per product, cells hold the summed sales
pivot_df = df.groupBy('store').pivot('product').agg(sum_('sales'))
pivot_df.show()
```
Output
```
+------+-------+-------+-------+
| store| Apples|Bananas|Oranges|
+------+-------+-------+-------+
|Store1|     10|     20|      7|
|Store2|     15|      5|     10|
+------+-------+-------+-------+
```
Common Pitfalls
Common mistakes when using pivot() include:
- Calling pivot() on a DataFrame that has not been grouped raises an error; pivot() is only available on GroupedData.
- Using pivot() without an aggregation function leaves the reshape unfinished; the result only materializes as a DataFrame once an aggregation is applied.
- Pivoting on a column with many unique values creates one output column per value and can cause serious performance problems.
- Combinations missing from the data appear as null after the pivot and may need filling.
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('PivotPitfall').getOrCreate()

data = [('A', 'X', 1), ('A', 'Y', 2)]
columns = ['group', 'category', 'value']
df = spark.createDataFrame(data, columns)

# Wrong: pivot without groupBy -- DataFrame has no pivot() method
# df.pivot('category').sum('value').show()  # AttributeError

# Right: groupBy before pivot
pivot_df = df.groupBy('group').pivot('category').sum('value')
pivot_df.show()
```
Output
```
+-----+---+---+
|group|  X|  Y|
+-----+---+---+
|    A|  1|  2|
+-----+---+---+
```
Key Takeaways
- Always call groupBy() before pivot(); pivot() is only available on grouped data.
- pivot() turns the unique values of one column into new columns holding aggregated data.
- Follow pivot() with an aggregation function such as sum or count to summarize the values.
- Be cautious with columns that have many unique values to avoid performance issues.
- Nulls after a pivot indicate missing combinations and may need handling.